1 /* MIT License
2 *
3 * Copyright (c) The c-ares project and its contributors
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a copy
6 * of this software and associated documentation files (the "Software"), to deal
7 * in the Software without restriction, including without limitation the rights
8 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 * copies of the Software, and to permit persons to whom the Software is
10 * furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 *
24 * SPDX-License-Identifier: MIT
25 */
26 // -*- mode: c++ -*-
27 #ifndef ARES_TEST_H
28 #define ARES_TEST_H
29
30 #include "ares_setup.h"
31 #include "dns-proto.h"
32 // Include ares internal file for DNS protocol constants
33 #include "ares_nameser.h"
34
35 #include "gtest/gtest.h"
36 #include "gmock/gmock.h"
37
// Container-based tests need both user and UTS namespace support; only then
// is HAVE_CONTAINER defined.
#if defined(HAVE_USER_NAMESPACE)
# if defined(HAVE_UTS_NAMESPACE)
#  define HAVE_CONTAINER
# endif
#endif
41
42 #include <functional>
43 #include <list>
44 #include <map>
45 #include <memory>
46 #include <set>
47 #include <string>
48 #include <mutex>
49 #include <thread>
50 #include <utility>
51 #include <vector>
52 #include <chrono>
53
// Portable spelling of "close a socket" for the test helpers.
#if defined(HAVE_CLOSESOCKET)
# define sclose(x) closesocket(x)
#elif defined(HAVE_CLOSESOCKET_CAMEL)
# define sclose(x) CloseSocket(x)
#elif defined(HAVE_CLOSE_S)
# define sclose(x) close_s(x)
#else
# define sclose(x) close(x)
#endif

#ifndef HAVE_WRITEV
// Platforms without writev() get a local struct iovec so that signatures
// using it (e.g. LibraryTest::ares_sendv_fail) still compile.
extern "C" {
/* Structure for scatter/gather I/O. */
struct iovec {
  void  *iov_base; /* Pointer to data. */
  size_t iov_len;  /* Length of data. */
};
} /* extern "C" -- no trailing ';' (it would be a stray empty declaration) */
#endif
73
74 namespace ares {
75
// Octet type used for raw packet buffers throughout the test helpers.
using byte = unsigned char;
77
78 namespace test {
79
80 extern bool verbose;
81 extern unsigned short mock_port;
82 extern const std::vector<int> both_families;
83 extern const std::vector<int> ipv4_family;
84 extern const std::vector<int> ipv6_family;
85
86 extern const std::vector<std::pair<int, bool>> both_families_both_modes;
87 extern const std::vector<std::pair<int, bool>> ipv4_family_both_modes;
88 extern const std::vector<std::pair<int, bool>> ipv6_family_both_modes;
89
90 extern const std::vector<std::tuple<ares_evsys_t, int, bool>>
91 all_evsys_ipv4_family_both_modes;
92 extern const std::vector<std::tuple<ares_evsys_t, int, bool>>
93 all_evsys_ipv6_family_both_modes;
94 extern const std::vector<std::tuple<ares_evsys_t, int, bool>>
95 all_evsys_both_families_both_modes;
96
97 extern const std::vector<std::tuple<ares_evsys_t, int>> all_evsys_ipv4_family;
98 extern const std::vector<std::tuple<ares_evsys_t, int>> all_evsys_ipv6_family;
99 extern const std::vector<std::tuple<ares_evsys_t, int>> all_evsys_both_families;
100
101 // Which parameters to use in tests
102 extern std::vector<int> families;
103 extern std::vector<std::tuple<ares_evsys_t, int>> evsys_families;
104 extern std::vector<std::pair<int, bool>> families_modes;
105 extern std::vector<std::tuple<ares_evsys_t, int, bool>> evsys_families_modes;
106
107 // Hopefully a more accurate sleep than sleep_for()
108 void ares_sleep_time(unsigned int ms);
109
110 // Process all pending work on ares-owned file descriptors, plus
111 // optionally the given set-of-FDs + work function.
112 void ProcessWork(ares_channel_t *channel,
113 std::function<std::set<ares_socket_t>()> get_extrafds,
114 std::function<void(ares_socket_t)> process_extra,
115 unsigned int cancel_ms = 0);
116 std::set<ares_socket_t> NoExtraFDs();
117
118 const char *af_tostr(int af);
119 const char *mode_tostr(bool mode);
120 std::string
121 PrintFamilyMode(const testing::TestParamInfo<std::pair<int, bool>> &info);
122 std::string PrintFamily(const testing::TestParamInfo<int> &info);
123
124 // Test fixture that ensures library initialization, and allows
125 // memory allocations to be failed.
126 class LibraryTest : public ::testing::Test {
127 public:
LibraryTest()128 LibraryTest()
129 {
130 EXPECT_EQ(ARES_SUCCESS, ares_library_init_mem(
131 ARES_LIB_INIT_ALL, &LibraryTest::amalloc,
132 &LibraryTest::afree, &LibraryTest::arealloc));
133 }
134
~LibraryTest()135 ~LibraryTest()
136 {
137 ares_library_cleanup();
138 ClearFails();
139 }
140
141 // Set the n-th malloc call (of any size) from the library to fail.
142 // (nth == 1 means the next call)
143 static void SetAllocFail(int nth);
144 // Set the next malloc call for the given size to fail.
145 static void SetAllocSizeFail(size_t size);
146 // Remove any pending alloc failures.
147 static void ClearFails();
148
149 static void *amalloc(size_t size);
150 static void *arealloc(void *ptr, size_t size);
151 static void afree(void *ptr);
152
153 static void SetFailSend(void);
154 static ares_ssize_t ares_sendv_fail(ares_socket_t socket, const struct iovec *vec, int len,
155 void *user_data);
156
157
158 private:
159 static bool ShouldAllocFail(size_t size);
160 static unsigned long long fails_;
161 static std::map<size_t, int> size_fails_;
162 static std::mutex lock_;
163 static bool failsend_;
164 };
165
166 // Test fixture that uses a default channel.
167 class DefaultChannelTest : public LibraryTest {
168 public:
DefaultChannelTest()169 DefaultChannelTest() : channel_(nullptr)
170 {
171 /* Enable query cache for live tests */
172 struct ares_options opts;
173 memset(&opts, 0, sizeof(opts));
174 opts.qcache_max_ttl = 300;
175 int optmask = ARES_OPT_QUERY_CACHE;
176 EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
177 EXPECT_NE(nullptr, channel_);
178 }
179
~DefaultChannelTest()180 ~DefaultChannelTest()
181 {
182 ares_destroy(channel_);
183 channel_ = nullptr;
184 }
185
186 // Process all pending work on ares-owned file descriptors.
187 void Process(unsigned int cancel_ms = 0);
188
189 protected:
190 ares_channel_t *channel_;
191 };
192
193 // Test fixture that uses a file-only channel.
194 class FileChannelTest : public LibraryTest {
195 public:
FileChannelTest()196 FileChannelTest() : channel_(nullptr)
197 {
198 struct ares_options opts;
199 memset(&opts, 0, sizeof(opts));
200 opts.lookups = strdup("f");
201 int optmask = ARES_OPT_LOOKUPS;
202 EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
203 EXPECT_NE(nullptr, channel_);
204 free(opts.lookups);
205 }
206
~FileChannelTest()207 ~FileChannelTest()
208 {
209 ares_destroy(channel_);
210 channel_ = nullptr;
211 }
212
213 // Process all pending work on ares-owned file descriptors.
214 void Process(unsigned int cancel_ms = 0);
215
216 protected:
217 ares_channel_t *channel_;
218 };
219
220 // Test fixture that uses a default channel with the specified lookup mode.
221 class DefaultChannelModeTest
222 : public LibraryTest,
223 public ::testing::WithParamInterface<std::string> {
224 public:
DefaultChannelModeTest()225 DefaultChannelModeTest() : channel_(nullptr)
226 {
227 struct ares_options opts;
228 memset(&opts, 0, sizeof(opts));
229 opts.lookups = strdup(GetParam().c_str());
230 int optmask = ARES_OPT_LOOKUPS;
231 EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
232 EXPECT_NE(nullptr, channel_);
233 free(opts.lookups);
234 }
235
~DefaultChannelModeTest()236 ~DefaultChannelModeTest()
237 {
238 ares_destroy(channel_);
239 channel_ = nullptr;
240 }
241
242 // Process all pending work on ares-owned file descriptors.
243 void Process(unsigned int cancel_ms = 0);
244
245 protected:
246 ares_channel_t *channel_;
247 };
248
249 // Mock DNS server to allow responses to be scripted by tests.
250 class MockServer {
251 public:
252 MockServer(int family, unsigned short port);
253 ~MockServer();
254
255 // Mock method indicating the processing of a particular <name, RRtype>
256 // request.
257 MOCK_METHOD2(OnRequest, void(const std::string &name, int rrtype));
258
259 // Set the reply to be sent next; the query ID field will be overwritten
260 // with the value from the request.
SetReplyData(const std::vector<byte> & reply)261 void SetReplyData(const std::vector<byte> &reply)
262 {
263 exact_reply_ = reply;
264 reply_ = nullptr;
265 }
266
SetReply(const DNSPacket * reply)267 void SetReply(const DNSPacket *reply)
268 {
269 reply_ = reply;
270 exact_reply_.clear();
271 }
272
273 // Set the reply to be sent next as well as the request (in string form) that
274 // the server should expect to receive; the query ID field in the reply will
275 // be overwritten with the value from the request.
SetReplyExpRequest(const DNSPacket * reply,const std::string & request)276 void SetReplyExpRequest(const DNSPacket *reply, const std::string &request)
277 {
278 expected_request_ = request;
279 reply_ = reply;
280 }
281
SetReplyQID(int qid)282 void SetReplyQID(int qid)
283 {
284 qid_ = qid;
285 }
286
Disconnect()287 void Disconnect()
288 {
289 reply_ = nullptr;
290 exact_reply_.clear();
291 for (ares_socket_t fd : connfds_) {
292 sclose(fd);
293 }
294 connfds_.clear();
295 free(tcp_data_);
296 tcp_data_ = NULL;
297 tcp_data_len_ = 0;
298 }
299
300 // The set of file descriptors that the server handles.
301 std::set<ares_socket_t> fds() const;
302
303 // Process activity on a file descriptor.
304 void ProcessFD(ares_socket_t fd);
305
306 // Ports the server is responding to
udpport()307 unsigned short udpport() const
308 {
309 return udpport_;
310 }
311
tcpport()312 unsigned short tcpport() const
313 {
314 return tcpport_;
315 }
316
317 private:
318 void ProcessRequest(ares_socket_t fd, struct sockaddr_storage *addr,
319 ares_socklen_t addrlen, const std::vector<byte> &req,
320 const std::string &reqstr, int qid, const char *name,
321 int rrtype);
322 void ProcessPacket(ares_socket_t fd, struct sockaddr_storage *addr,
323 ares_socklen_t addrlen, byte *data, int len);
324 unsigned short udpport_;
325 unsigned short tcpport_;
326 ares_socket_t udpfd_;
327 ares_socket_t tcpfd_;
328 std::set<ares_socket_t> connfds_;
329 std::vector<byte> exact_reply_;
330 const DNSPacket *reply_;
331 std::string expected_request_;
332 int qid_;
333 unsigned char *tcp_data_;
334 size_t tcp_data_len_;
335 };
336
337 // Test fixture that uses a mock DNS server.
338 class MockChannelOptsTest : public LibraryTest {
339 public:
340 MockChannelOptsTest(int count, int family, bool force_tcp,
341 bool honor_sysconfig, struct ares_options *givenopts,
342 int optmask);
343 ~MockChannelOptsTest();
344
345 // Process all pending work on ares-owned and mock-server-owned file
346 // descriptors.
347 void ProcessAltChannel(ares_channel_t *chan, unsigned int cancel_ms = 0);
348 void Process(unsigned int cancel_ms = 0);
349
350 protected:
351 // NiceMockServer doesn't complain about uninteresting calls.
352 typedef testing::NiceMock<MockServer> NiceMockServer;
353 typedef std::vector<std::unique_ptr<NiceMockServer>> NiceMockServers;
354
355 std::set<ares_socket_t> fds() const;
356 void ProcessFD(ares_socket_t fd);
357
358 static NiceMockServers BuildServers(int count, int family,
359 unsigned short base_port);
360
361 NiceMockServers servers_;
362 // Convenience reference to first server.
363 NiceMockServer &server_;
364 ares_channel_t *channel_;
365 };
366
367 class MockChannelTest
368 : public MockChannelOptsTest,
369 public ::testing::WithParamInterface<std::pair<int, bool>> {
370 public:
MockChannelTest()371 MockChannelTest()
372 : MockChannelOptsTest(1, GetParam().first, GetParam().second, false,
373 nullptr, 0)
374 {
375 }
376 };
377
378 class MockUDPChannelTest : public MockChannelOptsTest,
379 public ::testing::WithParamInterface<int> {
380 public:
MockUDPChannelTest()381 MockUDPChannelTest()
382 : MockChannelOptsTest(1, GetParam(), false, false, nullptr, 0)
383 {
384 }
385 };
386
387 class MockTCPChannelTest : public MockChannelOptsTest,
388 public ::testing::WithParamInterface<int> {
389 public:
MockTCPChannelTest()390 MockTCPChannelTest()
391 : MockChannelOptsTest(1, GetParam(), true, false, nullptr, 0)
392 {
393 }
394 };
395
396 class MockEventThreadOptsTest : public MockChannelOptsTest {
397 public:
MockEventThreadOptsTest(int count,ares_evsys_t evsys,int family,bool force_tcp,struct ares_options * givenopts,int optmask)398 MockEventThreadOptsTest(int count, ares_evsys_t evsys, int family,
399 bool force_tcp, struct ares_options *givenopts,
400 int optmask)
401 : MockChannelOptsTest(count, family, force_tcp, false,
402 FillOptionsET(&evopts_, givenopts, evsys),
403 optmask | ARES_OPT_EVENT_THREAD)
404 {
405 }
406
~MockEventThreadOptsTest()407 ~MockEventThreadOptsTest()
408 {
409 }
410
FillOptionsET(struct ares_options * opts,struct ares_options * givenopts,ares_evsys_t evsys)411 static struct ares_options *FillOptionsET(struct ares_options *opts,
412 struct ares_options *givenopts,
413 ares_evsys_t evsys)
414 {
415 if (givenopts) {
416 memcpy(opts, givenopts, sizeof(*opts));
417 } else {
418 memset(opts, 0, sizeof(*opts));
419 }
420 opts->evsys = evsys;
421 return opts;
422 }
423
424 void Process(unsigned int cancel_ms = 0);
425
426 private:
427 struct ares_options evopts_;
428 };
429
430 class MockEventThreadTest
431 : public MockEventThreadOptsTest,
432 public ::testing::WithParamInterface<std::tuple<ares_evsys_t, int, bool>> {
433 public:
MockEventThreadTest()434 MockEventThreadTest()
435 : MockEventThreadOptsTest(1, std::get<0>(GetParam()),
436 std::get<1>(GetParam()), std::get<2>(GetParam()),
437 nullptr, 0)
438 {
439 }
440 };
441
442 class MockUDPEventThreadTest
443 : public MockEventThreadOptsTest,
444 public ::testing::WithParamInterface<std::tuple<ares_evsys_t, int>> {
445 public:
MockUDPEventThreadTest()446 MockUDPEventThreadTest()
447 : MockEventThreadOptsTest(1, std::get<0>(GetParam()),
448 std::get<1>(GetParam()), false, nullptr, 0)
449 {
450 }
451 };
452
453 class MockTCPEventThreadTest
454 : public MockEventThreadOptsTest,
455 public ::testing::WithParamInterface<std::tuple<ares_evsys_t, int>> {
456 public:
MockTCPEventThreadTest()457 MockTCPEventThreadTest()
458 : MockEventThreadOptsTest(1, std::get<0>(GetParam()),
459 std::get<1>(GetParam()), true, nullptr, 0)
460 {
461 }
462 };
463
464 // gMock action to set the reply for a mock server.
ACTION_P2(SetReplyData,mockserver,data)465 ACTION_P2(SetReplyData, mockserver, data)
466 {
467 mockserver->SetReplyData(data);
468 }
469
ACTION_P2(SetReplyAndFailSend,mockserver,reply)470 ACTION_P2(SetReplyAndFailSend, mockserver, reply)
471 {
472 mockserver->SetReply(reply);
473 LibraryTest::SetFailSend();
474 }
475
ACTION_P2(SetReply,mockserver,reply)476 ACTION_P2(SetReply, mockserver, reply)
477 {
478 mockserver->SetReply(reply);
479 }
480
481 // gMock action to set the reply for a mock server, as well as the request (in
482 // string form) that the server should expect to receive.
ACTION_P3(SetReplyExpRequest,mockserver,reply,request)483 ACTION_P3(SetReplyExpRequest, mockserver, reply, request)
484 {
485 mockserver->SetReplyExpRequest(reply, request);
486 }
487
ACTION_P2(SetReplyQID,mockserver,qid)488 ACTION_P2(SetReplyQID, mockserver, qid)
489 {
490 mockserver->SetReplyQID(qid);
491 }
492
493 // gMock action to cancel a channel.
ACTION_P2(CancelChannel,mockserver,channel)494 ACTION_P2(CancelChannel, mockserver, channel)
495 {
496 ares_cancel(channel);
497 }
498
499 // gMock action to disconnect all connections.
ACTION_P(Disconnect,mockserver)500 ACTION_P(Disconnect, mockserver)
501 {
502 mockserver->Disconnect();
503 }
504
// C++ wrapper for struct hostent. Fields are held as std::string /
// std::vector values; the converting constructor (defined out of line)
// fills them from a hostent. A default-constructed HostEnt is empty with
// addrtype_ == -1.
struct HostEnt {
  HostEnt() = default;
  HostEnt(const struct hostent *hostent);

  std::string              name_;
  std::vector<std::string> aliases_;
  int                      addrtype_ = -1; // AF_INET or AF_INET6
  std::vector<std::string> addrs_;
};

std::ostream &operator<<(std::ostream &os, const HostEnt &result);
519
520 // Structure that describes the result of an ares_host_callback invocation.
521 struct HostResult {
HostResultHostResult522 HostResult() : done_(false), status_(0), timeouts_(0)
523 {
524 }
525
526 // Whether the callback has been invoked.
527 bool done_;
528 // Explicitly provided result information.
529 int status_;
530 int timeouts_;
531 // Contents of the hostent structure, if provided.
532 HostEnt host_;
533 };
534
535 std::ostream &operator<<(std::ostream &os, const HostResult &result);
536
537 // C++ wrapper for ares_dns_record_t.
538 struct AresDnsRecord {
~AresDnsRecordAresDnsRecord539 ~AresDnsRecord()
540 {
541 ares_dns_record_destroy(dnsrec_);
542 dnsrec_ = NULL;
543 }
544
AresDnsRecordAresDnsRecord545 AresDnsRecord() : dnsrec_(NULL)
546 {
547 }
548
SetDnsRecordAresDnsRecord549 void SetDnsRecord(const ares_dns_record_t *dnsrec)
550 {
551 if (dnsrec_ != NULL) {
552 ares_dns_record_destroy(dnsrec_);
553 }
554 if (dnsrec == NULL) {
555 return;
556 }
557 dnsrec_ = ares_dns_record_duplicate(dnsrec);
558 }
559
560 ares_dns_record_t *dnsrec_ = NULL;
561 };
562
563 std::ostream &operator<<(std::ostream &os, const AresDnsRecord &result);
564
565 // Structure that describes the result of an ares_host_callback invocation.
566 struct QueryResult {
QueryResultQueryResult567 QueryResult() : done_(false), status_(ARES_SUCCESS), timeouts_(0)
568 {
569 }
570
571 // Whether the callback has been invoked.
572 bool done_;
573 // Explicitly provided result information.
574 ares_status_t status_;
575 size_t timeouts_;
576 // Contents of the ares_dns_record_t structure if provided
577 AresDnsRecord dnsrec_;
578 };
579
580 std::ostream &operator<<(std::ostream &os, const QueryResult &result);
581
582 // Structure that describes the result of an ares_callback invocation.
583 struct SearchResult {
584 // Whether the callback has been invoked.
585 bool done_;
586 // Explicitly provided result information.
587 int status_;
588 int timeouts_;
589 std::vector<byte> data_;
590 };
591
592 std::ostream &operator<<(std::ostream &os, const SearchResult &result);
593
// Structure that describes the result of an ares_nameinfo_callback invocation.
// Every scalar member has a default initialiser. A result whose callback
// never fired therefore reads as "not done" rather than as indeterminate
// garbage.
struct NameInfoResult {
  // Whether the callback has been invoked.
  bool done_ = false;
  // Explicitly provided result information.
  int         status_   = 0;
  int         timeouts_ = 0;
  std::string node_;
  std::string service_;
};

std::ostream &operator<<(std::ostream &os, const NameInfoResult &result);
606
607 struct AddrInfoDeleter {
operatorAddrInfoDeleter608 void operator()(ares_addrinfo *ptr)
609 {
610 if (ptr) {
611 ares_freeaddrinfo(ptr);
612 }
613 }
614 };
615
616 // C++ wrapper for struct ares_addrinfo.
617 using AddrInfo = std::unique_ptr<ares_addrinfo, AddrInfoDeleter>;
618
619 std::ostream &operator<<(std::ostream &os, const AddrInfo &result);
620
621 // Structure that describes the result of an ares_addrinfo_callback invocation.
622 struct AddrInfoResult {
AddrInfoResultAddrInfoResult623 AddrInfoResult() : done_(false), status_(-1), timeouts_(0)
624 {
625 }
626
627 // Whether the callback has been invoked.
628 bool done_;
629 // Explicitly provided result information.
630 int status_;
631 int timeouts_;
632 // Contents of the ares_addrinfo structure, if provided.
633 AddrInfo ai_;
634 };
635
636 std::ostream &operator<<(std::ostream &os, const AddrInfoResult &result);
637
638 // Standard implementation of ares callbacks that fill out the corresponding
639 // structures.
640 void HostCallback(void *data, int status, int timeouts,
641 struct hostent *hostent);
642 void QueryCallback(void *data, ares_status_t status, size_t timeouts,
643 const ares_dns_record_t *dnsrec);
644 void SearchCallback(void *data, int status, int timeouts, unsigned char *abuf,
645 int alen);
646 void SearchCallbackDnsRec(void *data, ares_status_t status, size_t timeouts,
647 const ares_dns_record_t *dnsrec);
648 void NameInfoCallback(void *data, int status, int timeouts, char *node,
649 char *service);
650 void AddrInfoCallback(void *data, int status, int timeouts,
651 struct ares_addrinfo *res);
652
653 // Retrieve the name servers used by a channel.
654 std::string GetNameServers(ares_channel_t *channel);
655
// RAII class to temporarily create a directory of a given name; the
// constructor and destructor are defined out of line. Copying is disabled:
// a copy's destructor would otherwise clean up the same directory while the
// original still expects it to exist.
class TransientDir {
public:
  TransientDir(const std::string &dirname);
  ~TransientDir();

  TransientDir(const TransientDir &)            = delete;
  TransientDir &operator=(const TransientDir &) = delete;

private:
  std::string dirname_;
};

// C++ wrapper around tempnam()
std::string TempNam(const char *dir, const char *prefix);
668
// RAII class to temporarily create file of a given name and contents; the
// constructor and destructor are defined out of line.
// - Non-copyable, so only one object is ever responsible for the file.
// - The destructor is virtual because TempFile derives from this class and
//   owners may hold files through TransientFile pointers
//   (e.g. std::unique_ptr<TransientFile>).
class TransientFile {
public:
  TransientFile(const std::string &filename, const std::string &contents);
  virtual ~TransientFile();

  TransientFile(const TransientFile &)            = delete;
  TransientFile &operator=(const TransientFile &) = delete;

protected:
  std::string filename_;
};

// RAII class for a temporary file with the given contents; filename()
// exposes the generated name.
class TempFile : public TransientFile {
public:
  TempFile(const std::string &contents);

  const char *filename() const
  {
    return filename_.c_str();
  }
};
689
#ifdef _WIN32
extern "C" {

/* Minimal POSIX-style setenv() for Windows, built on _putenv().
 * - A NULL value is treated as "", which makes _putenv("name=") remove the
 *   variable; unsetenv() below relies on this.
 * - Returns 0 on success and -1 on failure.
 * - As with POSIX setenv(), leaving an existing variable untouched because
 *   overwrite == 0 counts as success.
 * Declared `static inline` so translation units that never call it do not
 * get unused-function warnings. */
static inline int setenv(const char *name, const char *value, int overwrite)
{
  char  *buffer;
  size_t buf_size;
  int    rc;

  /* Like POSIX, reject NULL/empty names and names containing '='. */
  if (name == NULL || *name == '\0' || strchr(name, '=') != NULL) {
    return -1;
  }

  if (value == NULL) {
    value = ""; /* For unset */
  }

  if (!overwrite && getenv(name) != NULL) {
    return 0; /* Existing value kept: not an error. */
  }

  buf_size = strlen(name) + strlen(value) + 1 /* = */ + 1 /* NUL */;
  buffer   = (char *)malloc(buf_size);
  if (buffer == NULL) {
    return -1;
  }
  _snprintf(buffer, buf_size, "%s=%s", name, value);
  rc = _putenv(buffer);
  free(buffer);
  return (rc == 0) ? 0 : -1;
}

/* Remove `name` from the environment (via setenv() with a NULL value). */
static inline int unsetenv(const char *name)
{
  return setenv(name, NULL, 1);
}

} /* extern "C" */
#endif
725
// RAII class for a temporary environment variable value.
// - The constructor remembers the current value of `name` (if any) and then
//   sets it to `value`.
// - The destructor puts the remembered value back, or unsets the variable
//   if it did not exist before.
// Copying is disabled, since every copy's destructor would restore/unset the
// variable again.
class EnvValue {
public:
  EnvValue(const char *name, const char *value) : name_(name), restore_(false)
  {
    const char *original = getenv(name);
    if (original != nullptr) {
      restore_  = true;
      original_ = original; // copy now: setenv() may invalidate `original`
    }
    setenv(name_.c_str(), value, 1);
  }

  EnvValue(const EnvValue &)            = delete;
  EnvValue &operator=(const EnvValue &) = delete;

  ~EnvValue()
  {
    if (restore_) {
      setenv(name_.c_str(), original_.c_str(), 1);
    } else {
      unsetenv(name_.c_str());
    }
  }

private:
  std::string name_;
  bool        restore_;
  std::string original_;
};
753
754
#ifdef HAVE_CONTAINER
// Linux-specific helpers for running test code inside a container; they are
// implemented in ares-test-ns.cc.
using VoidToIntFn     = std::function<int(void)>;
using NameContentList = std::vector<std::pair<std::string, std::string>>;

// Filesystem for a container, built from (name, contents) pairs and a mount
// point. NOTE(review): root() presumably names the directory the files are
// created under -- see ares-test-ns.cc for the actual behaviour.
class ContainerFilesystem {
public:
  ContainerFilesystem(NameContentList files, const std::string &mountpt);
  ~ContainerFilesystem();

  std::string root() const
  {
    return rootdir_;
  }

  std::string mountpt() const
  {
    return mountpt_;
  }

private:
  void EnsureDirExists(const std::string &dir);

  std::string                                 rootdir_;
  std::string                                 mountpt_;
  std::list<std::string>                      dirs_;
  std::vector<std::unique_ptr<TransientFile>> files_;
};

// Run `fn` in a container using `fs` and the given host/domain names. The
// CONTAINED_TEST_* macros below expect it to return 0.
int RunInContainer(ContainerFilesystem *fs, const std::string &hostname,
                   const std::string &domainname, VoidToIntFn fn);

// CONTAINED_TEST_F defines a TEST_F whose body (the code following the macro
// invocation) becomes a static InnerTestBody(). That body is run through
// RunInContainer() with a ContainerFilesystem built from `files`.
# define ICLASS_NAME(casename, testname) Contained##casename##_##testname
# define CONTAINED_TEST_F(casename, testname, hostname, domainname, files)  \
  class ICLASS_NAME(casename, testname) : public casename {                 \
  public:                                                                   \
    ICLASS_NAME(casename, testname)()                                       \
    {                                                                       \
    }                                                                       \
    static int InnerTestBody();                                             \
  };                                                                        \
  TEST_F(ICLASS_NAME(casename, testname), _)                                \
  {                                                                         \
    ContainerFilesystem chroot(files, "..");                                \
    VoidToIntFn fn(ICLASS_NAME(casename, testname)::InnerTestBody);         \
    EXPECT_EQ(0, RunInContainer(&chroot, hostname, domainname, fn));        \
  }                                                                         \
  int ICLASS_NAME(casename, testname)::InnerTestBody()

/* Derived from googletest/include/gtest/gtest-param-test.h, specifically the
 * TEST_P() macro, with fallbacks for attribute macros that older googletest
 * versions do not define. The generated TestBody() reinitialises the
 * fixture's channel_ inside the container, sleeps 100ms, and then runs the
 * user-supplied InnerTestBody(). */
# ifndef GTEST_ATTRIBUTE_UNUSED_
#  define GTEST_ATTRIBUTE_UNUSED_
# endif
# ifndef GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED
#  define GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED
# endif
# define CONTAINED_TEST_P(test_suite_name, test_name, hostname, domainname,   \
                          files)                                              \
  class GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)                    \
    : public test_suite_name {                                                \
  public:                                                                     \
    GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)()                      \
    {                                                                         \
    }                                                                         \
    int  InnerTestBody();                                                     \
    void TestBody() override                                                  \
    {                                                                         \
      ContainerFilesystem chroot(files, "..");                                \
      VoidToIntFn         fn = [this](void) -> int {                          \
        ares_reinit(this->channel_);                                          \
        ares_sleep_time(100);                                                 \
        return this->InnerTestBody();                                         \
      };                                                                      \
      EXPECT_EQ(0, RunInContainer(&chroot, hostname, domainname, fn));        \
    }                                                                         \
                                                                              \
  private:                                                                    \
    static int AddToRegistry()                                                \
    {                                                                         \
      ::testing::UnitTest::GetInstance()                                      \
        ->parameterized_test_registry()                                       \
        .GetTestSuitePatternHolder<test_suite_name>(                          \
          GTEST_STRINGIFY_(test_suite_name),                                  \
          ::testing::internal::CodeLocation(__FILE__, __LINE__))              \
        ->AddTestPattern(                                                     \
          GTEST_STRINGIFY_(test_suite_name), GTEST_STRINGIFY_(test_name),     \
          new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_(    \
            test_suite_name, test_name)>(),                                   \
          ::testing::internal::CodeLocation(__FILE__, __LINE__));             \
      return 0;                                                               \
    }                                                                         \
    GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED static int                          \
      gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_;                       \
  };                                                                          \
  int GTEST_TEST_CLASS_NAME_(test_suite_name,                                 \
                             test_name)::gtest_registering_dummy_ =           \
    GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::AddToRegistry();      \
  int GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::InnerTestBody()

#endif
858
859 /* Assigns virtual IO functions to a channel. These functions simply call
860 * the actual system functions.
861 */
862 class VirtualizeIO {
863 public:
864 VirtualizeIO(ares_channel);
865 ~VirtualizeIO();
866
867 static const ares_socket_functions default_functions;
868
869 private:
870 ares_channel_t *channel_;
871 };
872
/*
 * VIRT_NONVIRT_TEST_F(casename, testname) defines two gtest tests that share
 * one body (the code following the macro invocation):
 *   - casename.testname              runs the body as-is;
 *   - casename.testname_virtualized  runs it with a VirtualizeIO installed on
 *                                    the fixture's channel_.
 * This is slightly white-box. The virtualized run does no special socket
 * setup, so it suits only tests with very vanilla IO requirements.
 */
#define VCLASS_NAME(casename, testname) Virt##casename##_##testname
#define VIRT_NONVIRT_TEST_F(casename, testname)                        \
  class VCLASS_NAME(casename, testname) : public casename {            \
  public:                                                              \
    VCLASS_NAME(casename, testname)()                                  \
    {                                                                  \
    }                                                                  \
    void InnerTestBody();                                              \
  };                                                                   \
  GTEST_TEST_(casename, testname, VCLASS_NAME(casename, testname),     \
              ::testing::internal::GetTypeId<casename>())              \
  {                                                                    \
    InnerTestBody();                                                   \
  }                                                                    \
  GTEST_TEST_(casename, testname##_virtualized,                        \
              VCLASS_NAME(casename, testname),                         \
              ::testing::internal::GetTypeId<casename>())              \
  {                                                                    \
    VirtualizeIO vio(channel_);                                        \
    InnerTestBody();                                                   \
  }                                                                    \
  void VCLASS_NAME(casename, testname)::InnerTestBody()
903
904 } // namespace test
905 } // namespace ares
906
907 #endif
908