• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "resolver.hpp"
30 
31 #include "platform-posix.h"
32 
33 #include <openthread/logging.h>
34 #include <openthread/message.h>
35 #include <openthread/udp.h>
36 #include <openthread/platform/dns.h>
37 #include <openthread/platform/time.h>
38 
39 #include "common/code_utils.hpp"
40 
41 #include <arpa/inet.h>
42 #include <arpa/nameser.h>
43 #include <cassert>
44 #include <netinet/in.h>
45 #include <sys/select.h>
46 #include <sys/socket.h>
47 #include <unistd.h>
48 
49 #include <fstream>
50 #include <string>
51 
52 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
53 
54 namespace {
55 constexpr char kResolvConfFullPath[] = "/etc/resolv.conf";
56 constexpr char kNameserverItem[]     = "nameserver";
57 } // namespace
58 
59 extern ot::Posix::Resolver gResolver;
60 
61 namespace ot {
62 namespace Posix {
63 
64 const char Resolver::kLogModuleName[] = "Resolver";
65 
Init(void)66 void Resolver::Init(void)
67 {
68     memset(mUpstreamTransaction, 0, sizeof(mUpstreamTransaction));
69     LoadDnsServerListFromConf();
70 }
71 
TryRefreshDnsServerList(void)72 void Resolver::TryRefreshDnsServerList(void)
73 {
74     uint64_t now = otPlatTimeGet();
75 
76     if (now > mUpstreamDnsServerListFreshness + kDnsServerListCacheTimeoutMs ||
77         (mUpstreamDnsServerCount == 0 && now > mUpstreamDnsServerListFreshness + kDnsServerListNullCacheTimeoutMs))
78     {
79         LoadDnsServerListFromConf();
80     }
81 }
82 
LoadDnsServerListFromConf(void)83 void Resolver::LoadDnsServerListFromConf(void)
84 {
85     std::string   line;
86     std::ifstream fp;
87 
88     mUpstreamDnsServerCount = 0;
89 
90     fp.open(kResolvConfFullPath);
91 
92     while (fp.good() && std::getline(fp, line) && mUpstreamDnsServerCount < kMaxUpstreamServerCount)
93     {
94         if (line.find(kNameserverItem, 0) == 0)
95         {
96             in_addr_t addr;
97 
98             if (inet_pton(AF_INET, &line.c_str()[sizeof(kNameserverItem)], &addr) == 1)
99             {
100                 LogInfo("Got nameserver #%d: %s", mUpstreamDnsServerCount, &line.c_str()[sizeof(kNameserverItem)]);
101                 mUpstreamDnsServerList[mUpstreamDnsServerCount] = addr;
102                 mUpstreamDnsServerCount++;
103             }
104         }
105     }
106 
107     if (mUpstreamDnsServerCount == 0)
108     {
109         LogCrit("No domain name servers found in %s, default to 127.0.0.1", kResolvConfFullPath);
110     }
111 
112     mUpstreamDnsServerListFreshness = otPlatTimeGet();
113 }
114 
Query(otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)115 void Resolver::Query(otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
116 {
117     char        packet[kMaxDnsMessageSize];
118     otError     error  = OT_ERROR_NONE;
119     uint16_t    length = otMessageGetLength(aQuery);
120     sockaddr_in serverAddr;
121 
122     Transaction *txn = nullptr;
123 
124     VerifyOrExit(length <= kMaxDnsMessageSize, error = OT_ERROR_NO_BUFS);
125     VerifyOrExit(otMessageRead(aQuery, 0, &packet, sizeof(packet)) == length, error = OT_ERROR_NO_BUFS);
126 
127     txn = AllocateTransaction(aTxn);
128     VerifyOrExit(txn != nullptr, error = OT_ERROR_NO_BUFS);
129 
130     TryRefreshDnsServerList();
131 
132     serverAddr.sin_family = AF_INET;
133     serverAddr.sin_port   = htons(53);
134     for (int i = 0; i < mUpstreamDnsServerCount; i++)
135     {
136         serverAddr.sin_addr.s_addr = mUpstreamDnsServerList[i];
137         VerifyOrExit(
138             sendto(txn->mUdpFd, packet, length, MSG_DONTWAIT, (struct sockaddr *)&serverAddr, sizeof(serverAddr)) > 0,
139             error = OT_ERROR_NO_ROUTE);
140     }
141     LogInfo("Forwarded DNS query %p to %d server(s).", static_cast<void *>(aTxn), mUpstreamDnsServerCount);
142 
143 exit:
144     if (error != OT_ERROR_NONE)
145     {
146         LogCrit("Failed to forward DNS query %p to server: %d", static_cast<void *>(aTxn), error);
147     }
148     return;
149 }
150 
Cancel(otPlatDnsUpstreamQuery * aTxn)151 void Resolver::Cancel(otPlatDnsUpstreamQuery *aTxn)
152 {
153     Transaction *txn = GetTransaction(aTxn);
154 
155     if (txn != nullptr)
156     {
157         CloseTransaction(txn);
158     }
159 
160     otPlatDnsUpstreamQueryDone(gInstance, aTxn, nullptr);
161 }
162 
AllocateTransaction(otPlatDnsUpstreamQuery * aThreadTxn)163 Resolver::Transaction *Resolver::AllocateTransaction(otPlatDnsUpstreamQuery *aThreadTxn)
164 {
165     int          fdOrError = 0;
166     Transaction *ret       = nullptr;
167 
168     for (Transaction &txn : mUpstreamTransaction)
169     {
170         if (txn.mThreadTxn == nullptr)
171         {
172             fdOrError = socket(AF_INET, SOCK_DGRAM, 0);
173             if (fdOrError < 0)
174             {
175                 LogInfo("Failed to create socket for upstream resolver: %d", fdOrError);
176                 break;
177             }
178             ret             = &txn;
179             ret->mUdpFd     = fdOrError;
180             ret->mThreadTxn = aThreadTxn;
181             break;
182         }
183     }
184 
185     return ret;
186 }
187 
ForwardResponse(Transaction * aTxn)188 void Resolver::ForwardResponse(Transaction *aTxn)
189 {
190     char       response[kMaxDnsMessageSize];
191     ssize_t    readSize;
192     otError    error   = OT_ERROR_NONE;
193     otMessage *message = nullptr;
194 
195     VerifyOrExit((readSize = read(aTxn->mUdpFd, response, sizeof(response))) > 0);
196 
197     message = otUdpNewMessage(gInstance, nullptr);
198     VerifyOrExit(message != nullptr, error = OT_ERROR_NO_BUFS);
199     SuccessOrExit(error = otMessageAppend(message, response, readSize));
200 
201     otPlatDnsUpstreamQueryDone(gInstance, aTxn->mThreadTxn, message);
202     message = nullptr;
203 
204 exit:
205     if (readSize < 0)
206     {
207         LogInfo("Failed to read response from upstream resolver socket: %d", errno);
208     }
209     if (error != OT_ERROR_NONE)
210     {
211         LogInfo("Failed to forward upstream DNS response: %s", otThreadErrorToString(error));
212     }
213     if (message != nullptr)
214     {
215         otMessageFree(message);
216     }
217 }
218 
GetTransaction(int aFd)219 Resolver::Transaction *Resolver::GetTransaction(int aFd)
220 {
221     Transaction *ret = nullptr;
222 
223     for (Transaction &txn : mUpstreamTransaction)
224     {
225         if (txn.mThreadTxn != nullptr && txn.mUdpFd == aFd)
226         {
227             ret = &txn;
228             break;
229         }
230     }
231 
232     return ret;
233 }
234 
GetTransaction(otPlatDnsUpstreamQuery * aThreadTxn)235 Resolver::Transaction *Resolver::GetTransaction(otPlatDnsUpstreamQuery *aThreadTxn)
236 {
237     Transaction *ret = nullptr;
238 
239     for (Transaction &txn : mUpstreamTransaction)
240     {
241         if (txn.mThreadTxn == aThreadTxn)
242         {
243             ret = &txn;
244             break;
245         }
246     }
247 
248     return ret;
249 }
250 
CloseTransaction(Transaction * aTxn)251 void Resolver::CloseTransaction(Transaction *aTxn)
252 {
253     if (aTxn->mUdpFd >= 0)
254     {
255         close(aTxn->mUdpFd);
256         aTxn->mUdpFd = -1;
257     }
258     aTxn->mThreadTxn = nullptr;
259 }
260 
UpdateFdSet(otSysMainloopContext & aContext)261 void Resolver::UpdateFdSet(otSysMainloopContext &aContext)
262 {
263     for (Transaction &txn : mUpstreamTransaction)
264     {
265         if (txn.mThreadTxn != nullptr)
266         {
267             FD_SET(txn.mUdpFd, &aContext.mReadFdSet);
268             FD_SET(txn.mUdpFd, &aContext.mErrorFdSet);
269             if (txn.mUdpFd > aContext.mMaxFd)
270             {
271                 aContext.mMaxFd = txn.mUdpFd;
272             }
273         }
274     }
275 }
276 
Process(const otSysMainloopContext & aContext)277 void Resolver::Process(const otSysMainloopContext &aContext)
278 {
279     for (Transaction &txn : mUpstreamTransaction)
280     {
281         if (txn.mThreadTxn != nullptr)
282         {
283             // Note: On Linux, we can only get the error via read, so they should share the same logic.
284             if (FD_ISSET(txn.mUdpFd, &aContext.mErrorFdSet) || FD_ISSET(txn.mUdpFd, &aContext.mReadFdSet))
285             {
286                 ForwardResponse(&txn);
287                 CloseTransaction(&txn);
288             }
289         }
290     }
291 }
292 
293 } // namespace Posix
294 } // namespace ot
295 
otPlatDnsStartUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)296 void otPlatDnsStartUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
297 {
298     OT_UNUSED_VARIABLE(aInstance);
299 
300     gResolver.Query(aTxn, aQuery);
301 }
302 
otPlatDnsCancelUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn)303 void otPlatDnsCancelUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn)
304 {
305     OT_UNUSED_VARIABLE(aInstance);
306 
307     gResolver.Cancel(aTxn);
308 }
309 
310 #endif // OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
311