1 // -*- mode: c++ -*-
2 #ifndef ARES_TEST_H
3 #define ARES_TEST_H
4
5 #include "dns-proto.h"
6 // Include ares internal file for DNS protocol constants
7 #include "nameser.h"
8
9 #include "ares_setup.h"
10 #include "ares.h"
11
12 #include "gtest/gtest.h"
13 #include "gmock/gmock.h"
14
15 #ifdef HAVE_CONFIG_H
16 #include "config.h"
17 #endif
18 #if defined(HAVE_USER_NAMESPACE) && defined(HAVE_UTS_NAMESPACE)
19 #define HAVE_CONTAINER
20 #endif
21
#include <cstdlib>
#include <cstring>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
30
31 namespace ares {
32
// Raw octet type used for DNS wire-format buffers throughout the tests.
using byte = unsigned char;
34
35 namespace test {
36
// Global switches (defined in the test driver).
extern bool verbose;
extern int mock_port;  // port for the mock DNS servers -- presumably the base port; confirm in driver

// Canned address-family lists for parameterized tests.
extern const std::vector<int> both_families;
extern const std::vector<int> ipv4_family;
extern const std::vector<int> ipv6_family;

// Canned (family, force_tcp) pairs: each family combined with both transport
// modes, matching the parameter type MockChannelTest consumes.
extern const std::vector<std::pair<int, bool>> both_families_both_modes;
extern const std::vector<std::pair<int, bool>> ipv4_family_both_modes;
extern const std::vector<std::pair<int, bool>> ipv6_family_both_modes;

// The parameter sets the tests actually instantiate with.
extern std::vector<int> families;
extern std::vector<std::pair<int, bool>> families_modes;
50
51 // Process all pending work on ares-owned file descriptors, plus
52 // optionally the given set-of-FDs + work function.
53 void ProcessWork(ares_channel channel,
54 std::function<std::set<int>()> get_extrafds,
55 std::function<void(int)> process_extra);
56 std::set<int> NoExtraFDs();
57
58 // Test fixture that ensures library initialization, and allows
59 // memory allocations to be failed.
60 class LibraryTest : public ::testing::Test {
61 public:
LibraryTest()62 LibraryTest() {
63 EXPECT_EQ(ARES_SUCCESS,
64 ares_library_init_mem(ARES_LIB_INIT_ALL,
65 &LibraryTest::amalloc,
66 &LibraryTest::afree,
67 &LibraryTest::arealloc));
68 }
~LibraryTest()69 ~LibraryTest() {
70 ares_library_cleanup();
71 ClearFails();
72 }
73 // Set the n-th malloc call (of any size) from the library to fail.
74 // (nth == 1 means the next call)
75 static void SetAllocFail(int nth);
76 // Set the next malloc call for the given size to fail.
77 static void SetAllocSizeFail(size_t size);
78 // Remove any pending alloc failures.
79 static void ClearFails();
80
81 static void *amalloc(size_t size);
82 static void* arealloc(void *ptr, size_t size);
83 static void afree(void *ptr);
84 private:
85 static bool ShouldAllocFail(size_t size);
86 static unsigned long long fails_;
87 static std::map<size_t, int> size_fails_;
88 };
89
90 // Test fixture that uses a default channel.
91 class DefaultChannelTest : public LibraryTest {
92 public:
DefaultChannelTest()93 DefaultChannelTest() : channel_(nullptr) {
94 EXPECT_EQ(ARES_SUCCESS, ares_init(&channel_));
95 EXPECT_NE(nullptr, channel_);
96 }
97
~DefaultChannelTest()98 ~DefaultChannelTest() {
99 ares_destroy(channel_);
100 channel_ = nullptr;
101 }
102
103 // Process all pending work on ares-owned file descriptors.
104 void Process();
105
106 protected:
107 ares_channel channel_;
108 };
109
110 // Test fixture that uses a default channel with the specified lookup mode.
111 class DefaultChannelModeTest
112 : public LibraryTest,
113 public ::testing::WithParamInterface<std::string> {
114 public:
DefaultChannelModeTest()115 DefaultChannelModeTest() : channel_(nullptr) {
116 struct ares_options opts = {0};
117 opts.lookups = strdup(GetParam().c_str());
118 int optmask = ARES_OPT_LOOKUPS;
119 EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
120 EXPECT_NE(nullptr, channel_);
121 free(opts.lookups);
122 }
123
~DefaultChannelModeTest()124 ~DefaultChannelModeTest() {
125 ares_destroy(channel_);
126 channel_ = nullptr;
127 }
128
129 // Process all pending work on ares-owned file descriptors.
130 void Process();
131
132 protected:
133 ares_channel channel_;
134 };
135
136 // Mock DNS server to allow responses to be scripted by tests.
137 class MockServer {
138 public:
139 MockServer(int family, int port, int tcpport = 0);
140 ~MockServer();
141
142 // Mock method indicating the processing of a particular <name, RRtype>
143 // request.
144 MOCK_METHOD2(OnRequest, void(const std::string& name, int rrtype));
145
146 // Set the reply to be sent next; the query ID field will be overwritten
147 // with the value from the request.
SetReplyData(const std::vector<byte> & reply)148 void SetReplyData(const std::vector<byte>& reply) { reply_ = reply; }
SetReply(const DNSPacket * reply)149 void SetReply(const DNSPacket* reply) { SetReplyData(reply->data()); }
SetReplyQID(int qid)150 void SetReplyQID(int qid) { qid_ = qid; }
151
152 // The set of file descriptors that the server handles.
153 std::set<int> fds() const;
154
155 // Process activity on a file descriptor.
156 void ProcessFD(int fd);
157
158 // Ports the server is responding to
udpport()159 int udpport() const { return udpport_; }
tcpport()160 int tcpport() const { return tcpport_; }
161
162 private:
163 void ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen,
164 int qid, const std::string& name, int rrtype);
165
166 int udpport_;
167 int tcpport_;
168 int udpfd_;
169 int tcpfd_;
170 std::set<int> connfds_;
171 std::vector<byte> reply_;
172 int qid_;
173 };
174
175 // Test fixture that uses a mock DNS server.
176 class MockChannelOptsTest : public LibraryTest {
177 public:
178 MockChannelOptsTest(int count, int family, bool force_tcp, struct ares_options* givenopts, int optmask);
179 ~MockChannelOptsTest();
180
181 // Process all pending work on ares-owned and mock-server-owned file descriptors.
182 void Process();
183
184 protected:
185 // NiceMockServer doesn't complain about uninteresting calls.
186 typedef testing::NiceMock<MockServer> NiceMockServer;
187 typedef std::vector< std::unique_ptr<NiceMockServer> > NiceMockServers;
188
189 std::set<int> fds() const;
190 void ProcessFD(int fd);
191
192 static NiceMockServers BuildServers(int count, int family, int base_port);
193
194 NiceMockServers servers_;
195 // Convenience reference to first server.
196 NiceMockServer& server_;
197 ares_channel channel_;
198 };
199
200 class MockChannelTest
201 : public MockChannelOptsTest,
202 public ::testing::WithParamInterface< std::pair<int, bool> > {
203 public:
MockChannelTest()204 MockChannelTest() : MockChannelOptsTest(1, GetParam().first, GetParam().second, nullptr, 0) {}
205 };
206
207 class MockUDPChannelTest
208 : public MockChannelOptsTest,
209 public ::testing::WithParamInterface<int> {
210 public:
MockUDPChannelTest()211 MockUDPChannelTest() : MockChannelOptsTest(1, GetParam(), false, nullptr, 0) {}
212 };
213
214 class MockTCPChannelTest
215 : public MockChannelOptsTest,
216 public ::testing::WithParamInterface<int> {
217 public:
MockTCPChannelTest()218 MockTCPChannelTest() : MockChannelOptsTest(1, GetParam(), true, nullptr, 0) {}
219 };
220
// gMock actions for use in EXPECT_CALL(...).WillOnce(...)/WillRepeatedly(...)
// on MockServer::OnRequest. Each takes the target server pointer first.

// gMock action to set the reply for a mock server: forwards `data` to
// mockserver->SetReplyData().
ACTION_P2(SetReplyData, mockserver, data) {
  mockserver->SetReplyData(data);
}
// Forwards `reply` (a DNSPacket pointer) to mockserver->SetReply().
ACTION_P2(SetReply, mockserver, reply) {
  mockserver->SetReply(reply);
}
// Forwards `qid` to mockserver->SetReplyQID().
ACTION_P2(SetReplyQID, mockserver, qid) {
  mockserver->SetReplyQID(qid);
}
// gMock action to cancel a channel: calls ares_cancel(channel).
// NOTE: `mockserver` is unused here; it is accepted only so that every
// action in this family has the same (server, value) argument shape.
ACTION_P2(CancelChannel, mockserver, channel) {
  ares_cancel(channel);
}
235
// Value-type snapshot of a struct hostent, with owned copies of every field.
struct HostEnt {
  HostEnt() = default;
  HostEnt(const struct hostent* hostent);

  std::string name_;
  std::vector<std::string> aliases_;
  int addrtype_ = -1;  // AF_INET or AF_INET6; -1 until populated
  std::vector<std::string> addrs_;
};
std::ostream& operator<<(std::ostream& os, const HostEnt& result);
246
247 // Structure that describes the result of an ares_host_callback invocation.
248 struct HostResult {
249 // Whether the callback has been invoked.
250 bool done_;
251 // Explicitly provided result information.
252 int status_;
253 int timeouts_;
254 // Contents of the hostent structure, if provided.
255 HostEnt host_;
256 };
257 std::ostream& operator<<(std::ostream& os, const HostResult& result);
258
259 // Structure that describes the result of an ares_callback invocation.
260 struct SearchResult {
261 // Whether the callback has been invoked.
262 bool done_;
263 // Explicitly provided result information.
264 int status_;
265 int timeouts_;
266 std::vector<byte> data_;
267 };
268 std::ostream& operator<<(std::ostream& os, const SearchResult& result);
269
// Structure that describes the result of an ares_nameinfo_callback invocation.
struct NameInfoResult {
  // Start in a well-defined "callback not yet invoked" state, matching
  // AddrInfoResult, so that checking done_ before the callback has run does
  // not read an uninitialized value. status_ starts at -1 rather than 0 so an
  // unset result cannot be mistaken for ARES_SUCCESS.
  NameInfoResult() : done_(false), status_(-1), timeouts_(0) {}
  // Whether the callback has been invoked.
  bool done_;
  // Explicitly provided result information.
  int status_;
  int timeouts_;
  // Node and service strings handed to the callback, if any.
  std::string node_;
  std::string service_;
};
std::ostream& operator<<(std::ostream& os, const NameInfoResult& result);
281
282 struct AddrInfoDeleter {
operatorAddrInfoDeleter283 void operator() (ares_addrinfo *ptr) {
284 if (ptr) ares_freeaddrinfo(ptr);
285 }
286 };
287
288 // C++ wrapper for struct ares_addrinfo.
289 using AddrInfo = std::unique_ptr<ares_addrinfo, AddrInfoDeleter>;
290
291 std::ostream& operator<<(std::ostream& os, const AddrInfo& result);
292
293 // Structure that describes the result of an ares_addrinfo_callback invocation.
294 struct AddrInfoResult {
AddrInfoResultAddrInfoResult295 AddrInfoResult() : done_(false), status_(-1), timeouts_(0) {}
296 // Whether the callback has been invoked.
297 bool done_;
298 // Explicitly provided result information.
299 int status_;
300 int timeouts_;
301 // Contents of the ares_addrinfo structure, if provided.
302 AddrInfo ai_;
303 };
304 std::ostream& operator<<(std::ostream& os, const AddrInfoResult& result);
305
// Ready-made ares callbacks. In each one, `data` points at the matching
// *Result structure declared above, and the callback records its arguments
// into that structure.
void HostCallback(void *data, int status, int timeouts,
                  struct hostent *hostent);
void SearchCallback(void *data, int status, int timeouts,
                    unsigned char *abuf, int alen);
void NameInfoCallback(void *data, int status, int timeouts,
                      char *node, char *service);
void AddrInfoCallback(void *data, int status, int timeouts,
                      struct ares_addrinfo *res);
316
317 // Retrieve the name servers used by a channel.
318 std::vector<std::string> GetNameServers(ares_channel channel);
319
320
// RAII class to temporarily create a directory of a given name: created by
// the constructor and removed again by the destructor.
class TransientDir {
 public:
  TransientDir(const std::string& dirname);
  ~TransientDir();

  // Each object stands for one on-disk directory that its destructor cleans
  // up, so it must not be duplicated: a copy would remove the same directory
  // twice.
  TransientDir(const TransientDir&) = delete;
  TransientDir& operator=(const TransientDir&) = delete;

 private:
  std::string dirname_;
};
330
// tempnam(dir, prefix), with the result returned as a std::string.
std::string TempNam(const char *dir,
                    const char *prefix);
333
// RAII class to temporarily create a file of a given name and contents:
// written by the constructor and cleaned up by the destructor.
class TransientFile {
 public:
  TransientFile(const std::string &filename, const std::string &contents);
  ~TransientFile();

  // Each object stands for one on-disk file that its destructor cleans up,
  // so it must not be duplicated: a copy would remove the same file twice.
  TransientFile(const TransientFile&) = delete;
  TransientFile& operator=(const TransientFile&) = delete;

 protected:
  std::string filename_;
};
343
344 // RAII class for a temporary file with the given contents.
345 class TempFile : public TransientFile {
346 public:
347 TempFile(const std::string& contents);
filename()348 const char* filename() const { return filename_.c_str(); }
349 };
350
#ifdef _WIN32
extern "C" {

/* Minimal POSIX setenv() for Windows, built on _putenv().
 *
 * Returns 0 on success, including when the variable already exists and
 * `overwrite` is zero (POSIX defines that as success with no change).
 * Returns -1 for an invalid name (NULL, empty, or containing '='), on
 * allocation failure, or when _putenv() fails.
 *
 * A NULL `value` is treated as "". Note that _putenv("NAME=") removes NAME
 * from the environment, so on Windows an empty value unsets the variable
 * rather than setting it to an empty string.
 */
static int setenv(const char *name, const char *value, int overwrite)
{
  char *buffer;
  size_t buf_size;
  int rc;

  if (name == NULL || *name == '\0' || strchr(name, '=') != NULL)
    return -1;

  if (value == NULL)
    value = ""; /* For unset */

  if (!overwrite && getenv(name) != NULL)
    return 0; /* POSIX: leave the existing value alone, report success */

  buf_size = strlen(name) + strlen(value) + 1 /* = */ + 1 /* NUL */;
  buffer = static_cast<char *>(malloc(buf_size));
  if (buffer == NULL)
    return -1;
  _snprintf(buffer, buf_size, "%s=%s", name, value);
  /* NOTE(review): freeing right after _putenv() assumes the CRT keeps its
   * own copy of the string (as the UCRT does, unlike POSIX putenv()). */
  rc = _putenv(buffer);
  free(buffer);
  return rc == 0 ? 0 : -1;
}

/* POSIX unsetenv() for Windows: "NAME=" makes _putenv() remove NAME. */
static int unsetenv(const char *name)
{
  return setenv(name, NULL, 1);
}

} /* extern "C" */
#endif
384
// RAII class for a temporary environment variable value.
//
// The constructor remembers the current value of `name` (if any) and then
// sets the variable to `value`; a null `value` unsets it instead. The
// destructor restores the remembered value, or unsets the variable if it was
// not set originally.
class EnvValue {
 public:
  EnvValue(const char *name, const char *value) : name_(name), restore_(false) {
    // Copy the old value now: the pointer from getenv() may be invalidated by
    // the setenv()/unsetenv() call below.
    const char *original = getenv(name);
    if (original) {
      restore_ = true;
      original_ = original;
    }
    if (value) {
      setenv(name_.c_str(), value, 1);
    } else {
      // POSIX setenv() requires a non-null value string.
      unsetenv(name_.c_str());
    }
  }
  ~EnvValue() {
    if (restore_) {
      setenv(name_.c_str(), original_.c_str(), 1);
    } else {
      unsetenv(name_.c_str());
    }
  }

  // A copy would restore (or unset) the variable a second time, undoing any
  // change made in between.
  EnvValue(const EnvValue&) = delete;
  EnvValue& operator=(const EnvValue&) = delete;

 private:
  std::string name_;
  bool restore_;
  std::string original_;
};
408
409
#ifdef HAVE_CONTAINER
// Linux-only helpers for running a test body inside a freshly created
// container (enabled when both user and UTS namespaces are available); the
// implementation lives in ares-test-ns.cc.
using VoidToIntFn = std::function<int(void)>;
using NameContentList = std::vector<std::pair<std::string, std::string>>;

// Temporary directory tree, populated from (file name, contents) pairs, that
// is handed to RunInContainer as the container's filesystem.
class ContainerFilesystem {
 public:
  ContainerFilesystem(NameContentList files, const std::string& mountpt);
  ~ContainerFilesystem();

  std::string root() const {
    return rootdir_;
  }
  std::string mountpt() const {
    return mountpt_;
  }

 private:
  void EnsureDirExists(const std::string& dir);

  std::string rootdir_;
  std::string mountpt_;
  std::list<std::string> dirs_;
  std::vector<std::unique_ptr<TransientFile>> files_;
};

// Runs `fn` inside a container using `fs`, `hostname` and `domainname`;
// returns an int result (0 is treated as success by CONTAINED_TEST_F).
int RunInContainer(ContainerFilesystem* fs, const std::string& hostname,
                   const std::string& domainname, VoidToIntFn fn);

// CONTAINED_TEST_F(casename, testname, hostname, domainname, files) defines
// a test whose body (the block following the macro) runs inside a container
// and must return 0 to pass.
#define ICLASS_NAME(casename, testname) Contained##casename##_##testname
#define CONTAINED_TEST_F(casename, testname, hostname, domainname, files)  \
  class ICLASS_NAME(casename, testname) : public casename {                \
   public:                                                                 \
    ICLASS_NAME(casename, testname)() {}                                   \
    static int InnerTestBody();                                            \
  };                                                                       \
  TEST_F(ICLASS_NAME(casename, testname), _) {                             \
    ContainerFilesystem chroot(files, "..");                               \
    VoidToIntFn fn(ICLASS_NAME(casename, testname)::InnerTestBody);        \
    EXPECT_EQ(0, RunInContainer(&chroot, hostname, domainname, fn));       \
  }                                                                        \
  int ICLASS_NAME(casename, testname)::InnerTestBody()

#endif
448
449 /* Assigns virtual IO functions to a channel. These functions simply call
450 * the actual system functions.
451 */
452 class VirtualizeIO {
453 public:
454 VirtualizeIO(ares_channel);
455 ~VirtualizeIO();
456
457 static const ares_socket_functions default_functions;
458 private:
459 ares_channel channel_;
460 };
461
/*
 * Slightly white-box macro to generate two runs for a given test case:
 * one with no modifications, and one with all IO functions set to use
 * the virtual IO structure.
 *
 * VIRT_NONVIRT_TEST_F(casename, testname) { ... } defines a fixture subclass
 * whose InnerTestBody() is the block following the macro. It then registers
 * two gtest tests through gtest's internal GTEST_TEST_ macro:
 *   - casename.testname              runs InnerTestBody() unchanged;
 *   - casename.testname_virtualized  first constructs a VirtualizeIO on the
 *     fixture's channel_, then runs InnerTestBody().
 * The fixture must therefore expose a member named channel_.
 *
 * Since no special socket setup is done in the virtualized case, this
 * should probably only be used for tests with very vanilla IO
 * requirements.
 */
#define VCLASS_NAME(casename, testname) Virt##casename##_##testname
#define VIRT_NONVIRT_TEST_F(casename, testname)                             \
  class VCLASS_NAME(casename, testname) : public casename {                 \
  public:                                                                   \
    VCLASS_NAME(casename, testname)() {}                                    \
    void InnerTestBody();                                                   \
  };                                                                        \
  GTEST_TEST_(casename, testname, VCLASS_NAME(casename, testname),          \
              ::testing::internal::GetTypeId<casename>()) {                 \
    InnerTestBody();                                                        \
  }                                                                         \
  GTEST_TEST_(casename, testname##_virtualized,                             \
              VCLASS_NAME(casename, testname),                              \
              ::testing::internal::GetTypeId<casename>()) {                 \
    VirtualizeIO vio(channel_);                                             \
    InnerTestBody();                                                        \
  }                                                                         \
  void VCLASS_NAME(casename, testname)::InnerTestBody()
488
489 } // namespace test
490 } // namespace ares
491
492 #endif
493