• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* MIT License
2  *
3  * Copyright (c) The c-ares project and its contributors
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a copy
6  * of this software and associated documentation files (the "Software"), to deal
7  * in the Software without restriction, including without limitation the rights
8  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9  * copies of the Software, and to permit persons to whom the Software is
10  * furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  *
24  * SPDX-License-Identifier: MIT
25  */
26 // -*- mode: c++ -*-
27 #ifndef ARES_TEST_H
28 #define ARES_TEST_H
29 
30 #include "ares_setup.h"
31 #include "dns-proto.h"
32 // Include ares internal file for DNS protocol constants
33 #include "ares_nameser.h"
34 
35 #include "gtest/gtest.h"
36 #include "gmock/gmock.h"
37 
// HAVE_CONTAINER gates the container-based test helpers declared further down
// (ContainerFilesystem, RunInContainer, CONTAINED_TEST_F, CONTAINED_TEST_P);
// it is only enabled when both user and UTS namespace support were detected.
#if defined(HAVE_USER_NAMESPACE) && defined(HAVE_UTS_NAMESPACE)
#  define HAVE_CONTAINER
#endif
41 
42 #include <functional>
43 #include <list>
44 #include <map>
45 #include <memory>
46 #include <set>
47 #include <string>
48 #include <mutex>
49 #include <thread>
50 #include <utility>
51 #include <vector>
52 #include <chrono>
53 
// sclose(x): close socket descriptor x with whichever socket-close function
// the build configuration detected.  The checks are in priority order:
// closesocket(), then CloseSocket(), then close_s(), falling back to close().
#if defined(HAVE_CLOSESOCKET)
#  define sclose(x) closesocket(x)
#elif defined(HAVE_CLOSESOCKET_CAMEL)
#  define sclose(x) CloseSocket(x)
#elif defined(HAVE_CLOSE_S)
#  define sclose(x) close_s(x)
#else
#  define sclose(x) close(x)
#endif
63 
#ifndef HAVE_WRITEV
extern "C" {
/* Structure for scatter/gather I/O, provided here when the build did not
 * detect writev() (HAVE_WRITEV undefined).  Member order and types follow the
 * POSIX struct iovec.  Used e.g. by LibraryTest::ares_sendv_fail(). */
struct iovec {
  void  *iov_base; /* Pointer to data. */
  size_t iov_len;  /* Length of data.  */
};
} /* extern "C" -- no trailing ';' (it would be a stray empty declaration) */
#endif
73 
74 namespace ares {
75 
using byte = unsigned char;  // Raw octet type used for DNS wire data.
77 
78 namespace test {
79 
80 extern bool                                    verbose;
81 extern unsigned short                          mock_port;
82 extern const std::vector<int>                  both_families;
83 extern const std::vector<int>                  ipv4_family;
84 extern const std::vector<int>                  ipv6_family;
85 
86 extern const std::vector<std::pair<int, bool>> both_families_both_modes;
87 extern const std::vector<std::pair<int, bool>> ipv4_family_both_modes;
88 extern const std::vector<std::pair<int, bool>> ipv6_family_both_modes;
89 
90 extern const std::vector<std::tuple<ares_evsys_t, int, bool>>
91   all_evsys_ipv4_family_both_modes;
92 extern const std::vector<std::tuple<ares_evsys_t, int, bool>>
93   all_evsys_ipv6_family_both_modes;
94 extern const std::vector<std::tuple<ares_evsys_t, int, bool>>
95   all_evsys_both_families_both_modes;
96 
97 extern const std::vector<std::tuple<ares_evsys_t, int>> all_evsys_ipv4_family;
98 extern const std::vector<std::tuple<ares_evsys_t, int>> all_evsys_ipv6_family;
99 extern const std::vector<std::tuple<ares_evsys_t, int>> all_evsys_both_families;
100 
101 // Which parameters to use in tests
102 extern std::vector<int>                                 families;
103 extern std::vector<std::tuple<ares_evsys_t, int>>       evsys_families;
104 extern std::vector<std::pair<int, bool>>                families_modes;
105 extern std::vector<std::tuple<ares_evsys_t, int, bool>> evsys_families_modes;
106 
107 // Hopefully a more accurate sleep than sleep_for()
108 void                    ares_sleep_time(unsigned int ms);
109 
110 // Process all pending work on ares-owned file descriptors, plus
111 // optionally the given set-of-FDs + work function.
112 void                    ProcessWork(ares_channel_t                          *channel,
113                                     std::function<std::set<ares_socket_t>()> get_extrafds,
114                                     std::function<void(ares_socket_t)>       process_extra,
115                                     unsigned int                             cancel_ms = 0);
116 std::set<ares_socket_t> NoExtraFDs();
117 
118 const char             *af_tostr(int af);
119 const char             *mode_tostr(bool mode);
120 std::string
121   PrintFamilyMode(const testing::TestParamInfo<std::pair<int, bool>> &info);
122 std::string PrintFamily(const testing::TestParamInfo<int> &info);
123 
124 // Test fixture that ensures library initialization, and allows
125 // memory allocations to be failed.
126 class LibraryTest : public ::testing::Test {
127 public:
LibraryTest()128   LibraryTest()
129   {
130     EXPECT_EQ(ARES_SUCCESS, ares_library_init_mem(
131                               ARES_LIB_INIT_ALL, &LibraryTest::amalloc,
132                               &LibraryTest::afree, &LibraryTest::arealloc));
133   }
134 
~LibraryTest()135   ~LibraryTest()
136   {
137     ares_library_cleanup();
138     ClearFails();
139   }
140 
141   // Set the n-th malloc call (of any size) from the library to fail.
142   // (nth == 1 means the next call)
143   static void  SetAllocFail(int nth);
144   // Set the next malloc call for the given size to fail.
145   static void  SetAllocSizeFail(size_t size);
146   // Remove any pending alloc failures.
147   static void  ClearFails();
148 
149   static void *amalloc(size_t size);
150   static void *arealloc(void *ptr, size_t size);
151   static void  afree(void *ptr);
152 
153   static void SetFailSend(void);
154   static ares_ssize_t ares_sendv_fail(ares_socket_t socket, const struct iovec *vec, int len,
155                                       void *user_data);
156 
157 
158 private:
159   static bool                  ShouldAllocFail(size_t size);
160   static unsigned long long    fails_;
161   static std::map<size_t, int> size_fails_;
162   static std::mutex            lock_;
163   static bool                  failsend_;
164 };
165 
166 // Test fixture that uses a default channel.
167 class DefaultChannelTest : public LibraryTest {
168 public:
DefaultChannelTest()169   DefaultChannelTest() : channel_(nullptr)
170   {
171     /* Enable query cache for live tests */
172     struct ares_options opts;
173     memset(&opts, 0, sizeof(opts));
174     opts.qcache_max_ttl = 300;
175     int optmask         = ARES_OPT_QUERY_CACHE;
176     EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
177     EXPECT_NE(nullptr, channel_);
178   }
179 
~DefaultChannelTest()180   ~DefaultChannelTest()
181   {
182     ares_destroy(channel_);
183     channel_ = nullptr;
184   }
185 
186   // Process all pending work on ares-owned file descriptors.
187   void Process(unsigned int cancel_ms = 0);
188 
189 protected:
190   ares_channel_t *channel_;
191 };
192 
193 // Test fixture that uses a file-only channel.
194 class FileChannelTest : public LibraryTest {
195 public:
FileChannelTest()196   FileChannelTest() : channel_(nullptr)
197   {
198     struct ares_options opts;
199     memset(&opts, 0, sizeof(opts));
200     opts.lookups = strdup("f");
201     int optmask  = ARES_OPT_LOOKUPS;
202     EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
203     EXPECT_NE(nullptr, channel_);
204     free(opts.lookups);
205   }
206 
~FileChannelTest()207   ~FileChannelTest()
208   {
209     ares_destroy(channel_);
210     channel_ = nullptr;
211   }
212 
213   // Process all pending work on ares-owned file descriptors.
214   void Process(unsigned int cancel_ms = 0);
215 
216 protected:
217   ares_channel_t *channel_;
218 };
219 
220 // Test fixture that uses a default channel with the specified lookup mode.
221 class DefaultChannelModeTest
222   : public LibraryTest,
223     public ::testing::WithParamInterface<std::string> {
224 public:
DefaultChannelModeTest()225   DefaultChannelModeTest() : channel_(nullptr)
226   {
227     struct ares_options opts;
228     memset(&opts, 0, sizeof(opts));
229     opts.lookups = strdup(GetParam().c_str());
230     int optmask  = ARES_OPT_LOOKUPS;
231     EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
232     EXPECT_NE(nullptr, channel_);
233     free(opts.lookups);
234   }
235 
~DefaultChannelModeTest()236   ~DefaultChannelModeTest()
237   {
238     ares_destroy(channel_);
239     channel_ = nullptr;
240   }
241 
242   // Process all pending work on ares-owned file descriptors.
243   void Process(unsigned int cancel_ms = 0);
244 
245 protected:
246   ares_channel_t *channel_;
247 };
248 
249 // Mock DNS server to allow responses to be scripted by tests.
250 class MockServer {
251 public:
252   MockServer(int family, unsigned short port);
253   ~MockServer();
254 
255   // Mock method indicating the processing of a particular <name, RRtype>
256   // request.
257   MOCK_METHOD2(OnRequest, void(const std::string &name, int rrtype));
258 
259   // Set the reply to be sent next; the query ID field will be overwritten
260   // with the value from the request.
SetReplyData(const std::vector<byte> & reply)261   void SetReplyData(const std::vector<byte> &reply)
262   {
263     exact_reply_ = reply;
264     reply_       = nullptr;
265   }
266 
SetReply(const DNSPacket * reply)267   void SetReply(const DNSPacket *reply)
268   {
269     reply_ = reply;
270     exact_reply_.clear();
271   }
272 
273   // Set the reply to be sent next as well as the request (in string form) that
274   // the server should expect to receive; the query ID field in the reply will
275   // be overwritten with the value from the request.
SetReplyExpRequest(const DNSPacket * reply,const std::string & request)276   void SetReplyExpRequest(const DNSPacket *reply, const std::string &request)
277   {
278     expected_request_ = request;
279     reply_            = reply;
280   }
281 
SetReplyQID(int qid)282   void SetReplyQID(int qid)
283   {
284     qid_ = qid;
285   }
286 
Disconnect()287   void Disconnect()
288   {
289     reply_ = nullptr;
290     exact_reply_.clear();
291     for (ares_socket_t fd : connfds_) {
292       sclose(fd);
293     }
294     connfds_.clear();
295     free(tcp_data_);
296     tcp_data_     = NULL;
297     tcp_data_len_ = 0;
298   }
299 
300   // The set of file descriptors that the server handles.
301   std::set<ares_socket_t> fds() const;
302 
303   // Process activity on a file descriptor.
304   void                    ProcessFD(ares_socket_t fd);
305 
306   // Ports the server is responding to
udpport()307   unsigned short          udpport() const
308   {
309     return udpport_;
310   }
311 
tcpport()312   unsigned short tcpport() const
313   {
314     return tcpport_;
315   }
316 
317 private:
318   void           ProcessRequest(ares_socket_t fd, struct sockaddr_storage *addr,
319                                 ares_socklen_t addrlen, const std::vector<byte> &req,
320                                 const std::string &reqstr, int qid, const char *name,
321                                 int rrtype);
322   void           ProcessPacket(ares_socket_t fd, struct sockaddr_storage *addr,
323                                ares_socklen_t addrlen, byte *data, int len);
324   unsigned short udpport_;
325   unsigned short tcpport_;
326   ares_socket_t  udpfd_;
327   ares_socket_t  tcpfd_;
328   std::set<ares_socket_t> connfds_;
329   std::vector<byte>       exact_reply_;
330   const DNSPacket        *reply_;
331   std::string             expected_request_;
332   int                     qid_;
333   unsigned char          *tcp_data_;
334   size_t                  tcp_data_len_;
335 };
336 
337 // Test fixture that uses a mock DNS server.
338 class MockChannelOptsTest : public LibraryTest {
339 public:
340   MockChannelOptsTest(int count, int family, bool force_tcp,
341                       bool honor_sysconfig, struct ares_options *givenopts,
342                       int optmask);
343   ~MockChannelOptsTest();
344 
345   // Process all pending work on ares-owned and mock-server-owned file
346   // descriptors.
347   void ProcessAltChannel(ares_channel_t *chan, unsigned int cancel_ms = 0);
348   void Process(unsigned int cancel_ms = 0);
349 
350 protected:
351   // NiceMockServer doesn't complain about uninteresting calls.
352   typedef testing::NiceMock<MockServer>                NiceMockServer;
353   typedef std::vector<std::unique_ptr<NiceMockServer>> NiceMockServers;
354 
355   std::set<ares_socket_t>                              fds() const;
356   void                   ProcessFD(ares_socket_t fd);
357 
358   static NiceMockServers BuildServers(int count, int family,
359                                       unsigned short base_port);
360 
361   NiceMockServers        servers_;
362   // Convenience reference to first server.
363   NiceMockServer        &server_;
364   ares_channel_t        *channel_;
365 };
366 
367 class MockChannelTest
368   : public MockChannelOptsTest,
369     public ::testing::WithParamInterface<std::pair<int, bool>> {
370 public:
MockChannelTest()371   MockChannelTest()
372     : MockChannelOptsTest(1, GetParam().first, GetParam().second, false,
373                           nullptr, 0)
374   {
375   }
376 };
377 
378 class MockUDPChannelTest : public MockChannelOptsTest,
379                            public ::testing::WithParamInterface<int> {
380 public:
MockUDPChannelTest()381   MockUDPChannelTest()
382     : MockChannelOptsTest(1, GetParam(), false, false, nullptr, 0)
383   {
384   }
385 };
386 
387 class MockTCPChannelTest : public MockChannelOptsTest,
388                            public ::testing::WithParamInterface<int> {
389 public:
MockTCPChannelTest()390   MockTCPChannelTest()
391     : MockChannelOptsTest(1, GetParam(), true, false, nullptr, 0)
392   {
393   }
394 };
395 
396 class MockEventThreadOptsTest : public MockChannelOptsTest {
397 public:
MockEventThreadOptsTest(int count,ares_evsys_t evsys,int family,bool force_tcp,struct ares_options * givenopts,int optmask)398   MockEventThreadOptsTest(int count, ares_evsys_t evsys, int family,
399                           bool force_tcp, struct ares_options *givenopts,
400                           int optmask)
401     : MockChannelOptsTest(count, family, force_tcp, false,
402                           FillOptionsET(&evopts_, givenopts, evsys),
403                           optmask | ARES_OPT_EVENT_THREAD)
404   {
405   }
406 
~MockEventThreadOptsTest()407   ~MockEventThreadOptsTest()
408   {
409   }
410 
FillOptionsET(struct ares_options * opts,struct ares_options * givenopts,ares_evsys_t evsys)411   static struct ares_options *FillOptionsET(struct ares_options *opts,
412                                             struct ares_options *givenopts,
413                                             ares_evsys_t         evsys)
414   {
415     if (givenopts) {
416       memcpy(opts, givenopts, sizeof(*opts));
417     } else {
418       memset(opts, 0, sizeof(*opts));
419     }
420     opts->evsys = evsys;
421     return opts;
422   }
423 
424   void Process(unsigned int cancel_ms = 0);
425 
426 private:
427   struct ares_options evopts_;
428 };
429 
430 class MockEventThreadTest
431   : public MockEventThreadOptsTest,
432     public ::testing::WithParamInterface<std::tuple<ares_evsys_t, int, bool>> {
433 public:
MockEventThreadTest()434   MockEventThreadTest()
435     : MockEventThreadOptsTest(1, std::get<0>(GetParam()),
436                               std::get<1>(GetParam()), std::get<2>(GetParam()),
437                               nullptr, 0)
438   {
439   }
440 };
441 
442 class MockUDPEventThreadTest
443   : public MockEventThreadOptsTest,
444     public ::testing::WithParamInterface<std::tuple<ares_evsys_t, int>> {
445 public:
MockUDPEventThreadTest()446   MockUDPEventThreadTest()
447     : MockEventThreadOptsTest(1, std::get<0>(GetParam()),
448                               std::get<1>(GetParam()), false, nullptr, 0)
449   {
450   }
451 };
452 
453 class MockTCPEventThreadTest
454   : public MockEventThreadOptsTest,
455     public ::testing::WithParamInterface<std::tuple<ares_evsys_t, int>> {
456 public:
MockTCPEventThreadTest()457   MockTCPEventThreadTest()
458     : MockEventThreadOptsTest(1, std::get<0>(GetParam()),
459                               std::get<1>(GetParam()), true, nullptr, 0)
460   {
461   }
462 };
463 
// gMock action to set the reply for a mock server.
// `mockserver` is a pointer to the MockServer; `data` is forwarded to
// MockServer::SetReplyData() (raw reply bytes).
ACTION_P2(SetReplyData, mockserver, data)
{
  mockserver->SetReplyData(data);
}

// gMock action: set the next reply packet and then arm LibraryTest's
// send-failure injection (LibraryTest::SetFailSend()).
ACTION_P2(SetReplyAndFailSend, mockserver, reply)
{
  mockserver->SetReply(reply);
  LibraryTest::SetFailSend();
}

// gMock action: set the next reply packet via MockServer::SetReply().
ACTION_P2(SetReply, mockserver, reply)
{
  mockserver->SetReply(reply);
}

// gMock action to set the reply for a mock server, as well as the request (in
// string form) that the server should expect to receive.
ACTION_P3(SetReplyExpRequest, mockserver, reply, request)
{
  mockserver->SetReplyExpRequest(reply, request);
}

// gMock action: override the query ID the mock server puts in replies.
ACTION_P2(SetReplyQID, mockserver, qid)
{
  mockserver->SetReplyQID(qid);
}

// gMock action to cancel a channel (ares_cancel).  `mockserver` is not used
// by the body.
ACTION_P2(CancelChannel, mockserver, channel)
{
  ares_cancel(channel);
}

// gMock action to disconnect all connections (MockServer::Disconnect()).
ACTION_P(Disconnect, mockserver)
{
  mockserver->Disconnect();
}
504 
// C++ wrapper for struct hostent, holding its fields as standard-library
// types.
struct HostEnt {
  // Empty wrapper; addrtype_ stays -1 until populated.
  HostEnt()
  {
  }

  // Populate from a C hostent (defined out of line).
  HostEnt(const struct hostent *hostent);

  std::string              name_;
  std::vector<std::string> aliases_;
  int                      addrtype_ = -1; // AF_INET or AF_INET6
  std::vector<std::string> addrs_;
};

std::ostream &operator<<(std::ostream &os, const HostEnt &result);
519 
520 // Structure that describes the result of an ares_host_callback invocation.
521 struct HostResult {
HostResultHostResult522   HostResult() : done_(false), status_(0), timeouts_(0)
523   {
524   }
525 
526   // Whether the callback has been invoked.
527   bool    done_;
528   // Explicitly provided result information.
529   int     status_;
530   int     timeouts_;
531   // Contents of the hostent structure, if provided.
532   HostEnt host_;
533 };
534 
535 std::ostream &operator<<(std::ostream &os, const HostResult &result);
536 
537 // C++ wrapper for ares_dns_record_t.
538 struct AresDnsRecord {
~AresDnsRecordAresDnsRecord539   ~AresDnsRecord()
540   {
541     ares_dns_record_destroy(dnsrec_);
542     dnsrec_ = NULL;
543   }
544 
AresDnsRecordAresDnsRecord545   AresDnsRecord() : dnsrec_(NULL)
546   {
547   }
548 
SetDnsRecordAresDnsRecord549   void SetDnsRecord(const ares_dns_record_t *dnsrec)
550   {
551     if (dnsrec_ != NULL) {
552       ares_dns_record_destroy(dnsrec_);
553     }
554     if (dnsrec == NULL) {
555       return;
556     }
557     dnsrec_ = ares_dns_record_duplicate(dnsrec);
558   }
559 
560   ares_dns_record_t *dnsrec_ = NULL;
561 };
562 
563 std::ostream &operator<<(std::ostream &os, const AresDnsRecord &result);
564 
565 // Structure that describes the result of an ares_host_callback invocation.
566 struct QueryResult {
QueryResultQueryResult567   QueryResult() : done_(false), status_(ARES_SUCCESS), timeouts_(0)
568   {
569   }
570 
571   // Whether the callback has been invoked.
572   bool          done_;
573   // Explicitly provided result information.
574   ares_status_t status_;
575   size_t        timeouts_;
576   // Contents of the ares_dns_record_t structure if provided
577   AresDnsRecord dnsrec_;
578 };
579 
580 std::ostream &operator<<(std::ostream &os, const QueryResult &result);
581 
582 // Structure that describes the result of an ares_callback invocation.
583 struct SearchResult {
584   // Whether the callback has been invoked.
585   bool              done_;
586   // Explicitly provided result information.
587   int               status_;
588   int               timeouts_;
589   std::vector<byte> data_;
590 };
591 
592 std::ostream &operator<<(std::ostream &os, const SearchResult &result);
593 
// Structure that describes the result of an ares_nameinfo_callback invocation.
//
// Members have defaults so that a callback which never runs leaves
// done_ == false instead of an indeterminate value.
struct NameInfoResult {
  // Whether the callback has been invoked.
  bool        done_     = false;
  // Explicitly provided result information.
  int         status_   = 0;
  int         timeouts_ = 0;
  std::string node_;
  std::string service_;
};

std::ostream &operator<<(std::ostream &os, const NameInfoResult &result);
606 
607 struct AddrInfoDeleter {
operatorAddrInfoDeleter608   void operator()(ares_addrinfo *ptr)
609   {
610     if (ptr) {
611       ares_freeaddrinfo(ptr);
612     }
613   }
614 };
615 
616 // C++ wrapper for struct ares_addrinfo.
617 using AddrInfo = std::unique_ptr<ares_addrinfo, AddrInfoDeleter>;
618 
619 std::ostream &operator<<(std::ostream &os, const AddrInfo &result);
620 
621 // Structure that describes the result of an ares_addrinfo_callback invocation.
622 struct AddrInfoResult {
AddrInfoResultAddrInfoResult623   AddrInfoResult() : done_(false), status_(-1), timeouts_(0)
624   {
625   }
626 
627   // Whether the callback has been invoked.
628   bool     done_;
629   // Explicitly provided result information.
630   int      status_;
631   int      timeouts_;
632   // Contents of the ares_addrinfo structure, if provided.
633   AddrInfo ai_;
634 };
635 
636 std::ostream &operator<<(std::ostream &os, const AddrInfoResult &result);
637 
638 // Standard implementation of ares callbacks that fill out the corresponding
639 // structures.
640 void          HostCallback(void *data, int status, int timeouts,
641                            struct hostent *hostent);
642 void          QueryCallback(void *data, ares_status_t status, size_t timeouts,
643                             const ares_dns_record_t *dnsrec);
644 void SearchCallback(void *data, int status, int timeouts, unsigned char *abuf,
645                     int alen);
646 void SearchCallbackDnsRec(void *data, ares_status_t status, size_t timeouts,
647                           const ares_dns_record_t *dnsrec);
648 void NameInfoCallback(void *data, int status, int timeouts, char *node,
649                       char *service);
650 void AddrInfoCallback(void *data, int status, int timeouts,
651                       struct ares_addrinfo *res);
652 
653 // Retrieve the name servers used by a channel.
654 std::string GetNameServers(ares_channel_t *channel);
655 
// RAII class to temporarily create a directory of a given name (constructor
// creates it, destructor removes it; both defined out of line).
class TransientDir {
public:
  TransientDir(const std::string &dirname);
  ~TransientDir();

private:
  std::string dirname_;
};

// C++ wrapper around tempnam(): returns a temporary file name in `dir`
// beginning with `prefix`.
std::string TempNam(const char *dir, const char *prefix);

// RAII class to temporarily create a file of a given name and contents.
class TransientFile {
public:
  TransientFile(const std::string &filename, const std::string &contents);
  ~TransientFile();

protected:
  std::string filename_;
};

// RAII class for a temporary file, with a generated name, holding the given
// contents.
class TempFile : public TransientFile {
public:
  TempFile(const std::string &contents);

  // Path of the temporary file.
  const char *filename() const
  {
    const std::string &path = filename_;
    return path.c_str();
  }
};
689 
#ifdef _WIN32
extern "C" {

/* Minimal POSIX-style setenv() for Windows, built on _putenv().
 *
 * - Returns -1 if name is NULL, empty, or contains '=' (POSIX: EINVAL).
 * - A NULL value is treated as "" ; _putenv("NAME=") removes the variable,
 *   which is what unsetenv() below relies on.
 * - If the variable already exists and overwrite is 0, it is left unchanged
 *   and 0 (success) is returned, as POSIX specifies.
 * - Returns -1 if the temporary buffer cannot be allocated or _putenv()
 *   fails, 0 otherwise.
 */
static int setenv(const char *name, const char *value, int overwrite)
{
  char  *buffer;
  size_t buf_size;
  int    rv;

  if (name == NULL || *name == '\0' || strchr(name, '=') != NULL) {
    return -1;
  }

  if (value == NULL) {
    value = ""; /* For unset */
  }

  if (!overwrite && getenv(name) != NULL) {
    return 0; /* POSIX: keep the existing value, still a success */
  }

  buf_size = strlen(name) + strlen(value) + 1 /* = */ + 1 /* NULL */;
  buffer   = (char *)malloc(buf_size);
  if (buffer == NULL) {
    return -1;
  }
  _snprintf(buffer, buf_size, "%s=%s", name, value);
  rv = _putenv(buffer);
  free(buffer);
  return (rv == 0) ? 0 : -1;
}

/* Remove `name` from the environment (setenv with a NULL value). */
static int unsetenv(const char *name)
{
  return setenv(name, NULL, 1);
}

} /* extern "C" */
#endif
725 
// RAII class for a temporary environment variable value.
//
// The constructor records whether `name` is currently set (and its value),
// then applies `value`; a null `value` removes the variable.  The destructor
// restores the recorded value, or removes the variable if it was not set.
class EnvValue {
public:
  EnvValue(const char *name, const char *value) : name_(name), restore_(false)
  {
    char *original = getenv(name);
    if (original) {
      restore_  = true;
      original_ = original;
    }
    if (value != nullptr) {
      setenv(name_.c_str(), value, 1);
    } else {
      // POSIX setenv() must not be given a NULL value; treat it as "unset",
      // which is also what the _WIN32 setenv() shim does with NULL.
      unsetenv(name_.c_str());
    }
  }

  ~EnvValue()
  {
    if (restore_) {
      setenv(name_.c_str(), original_.c_str(), 1);
    } else {
      unsetenv(name_.c_str());
    }
  }

private:
  std::string name_;
  bool        restore_;
  std::string original_;
};
753 
754 
755 #ifdef HAVE_CONTAINER
756 // Linux-specific functionality for running code in a container, implemented
757 // in ares-test-ns.cc
758 typedef std::function<int(void)>                         VoidToIntFn;
759 typedef std::vector<std::pair<std::string, std::string>> NameContentList;
760 
761 class ContainerFilesystem {
762 public:
763   ContainerFilesystem(NameContentList files, const std::string &mountpt);
764   ~ContainerFilesystem();
765 
root()766   std::string root() const
767   {
768     return rootdir_;
769   }
770 
mountpt()771   std::string mountpt() const
772   {
773     return mountpt_;
774   }
775 
776 private:
777   void                   EnsureDirExists(const std::string &dir);
778   std::string            rootdir_;
779   std::string            mountpt_;
780   std::list<std::string> dirs_;
781   std::vector<std::unique_ptr<TransientFile>> files_;
782 };
783 
784 int RunInContainer(ContainerFilesystem *fs, const std::string &hostname,
785                    const std::string &domainname, VoidToIntFn fn);
786 
// Name of the helper fixture class generated by CONTAINED_TEST_F.
#  define ICLASS_NAME(casename, testname) Contained##casename##_##testname
// CONTAINED_TEST_F(casename, testname, hostname, domainname, files) { ... }
//
// Declares a subclass of fixture `casename` with a static int InnerTestBody()
// and registers a TEST_F named "_" on it.  That test builds a
// ContainerFilesystem from `files` (mount point "..") and expects
// RunInContainer(&chroot, hostname, domainname, InnerTestBody) to return 0.
// The braces following the macro invocation become the body of
// InnerTestBody(); since it is static, it cannot access fixture members.
#  define CONTAINED_TEST_F(casename, testname, hostname, domainname, files)   \
    class ICLASS_NAME(casename, testname) : public casename {                 \
    public:                                                                   \
      ICLASS_NAME(casename, testname)()                                       \
      {                                                                       \
      }                                                                       \
      static int InnerTestBody();                                             \
    };                                                                        \
    TEST_F(ICLASS_NAME(casename, testname), _)                                \
    {                                                                         \
      ContainerFilesystem chroot(files, "..");                                \
      VoidToIntFn         fn(ICLASS_NAME(casename, testname)::InnerTestBody); \
      EXPECT_EQ(0, RunInContainer(&chroot, hostname, domainname, fn));        \
    }                                                                         \
    int ICLASS_NAME(casename, testname)::InnerTestBody()
803 
804 
/* Derived from googletest/include/gtest/gtest-param-test.h, specifically the
 * TEST_P() macro, and some fixes to try to be compatible with different
 * versions.  The two fallbacks below define attribute macros to nothing when
 * the installed googletest does not provide them. */
#  ifndef GTEST_ATTRIBUTE_UNUSED_
#    define GTEST_ATTRIBUTE_UNUSED_
#  endif
#  ifndef GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED
#    define GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED
#  endif
// CONTAINED_TEST_P(test_suite_name, test_name, hostname, domainname, files)
// { ... } -- a hand-expanded TEST_P for parameterised fixtures.
//
// The generated class derives from `test_suite_name` and its TestBody()
// builds a ContainerFilesystem from `files` (mount point ".."), then expects
// RunInContainer(...) to return 0 for a lambda that calls
// ares_reinit(this->channel_), sleeps 100 ms via ares_sleep_time(), and
// returns InnerTestBody().  The fixture must therefore expose `channel_`.
// Registration happens through the static gtest_registering_dummy_
// initializer (AddToRegistry() -> parameterized_test_registry()).  The braces
// after the invocation become the (non-static) InnerTestBody() body.
#  define CONTAINED_TEST_P(test_suite_name, test_name, hostname, domainname, \
                           files)                                            \
    class GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)                 \
      : public test_suite_name {                                             \
    public:                                                                  \
      GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)()                   \
      {                                                                      \
      }                                                                      \
      int  InnerTestBody();                                                  \
      void TestBody()                                                        \
      {                                                                      \
        ContainerFilesystem chroot(files, "..");                             \
        VoidToIntFn         fn = [this](void) -> int {                       \
          ares_reinit(this->channel_);                               \
          ares_sleep_time(100);                                      \
          return this->InnerTestBody();                              \
        };                                                                   \
        EXPECT_EQ(0, RunInContainer(&chroot, hostname, domainname, fn));     \
      }                                                                      \
                                                                             \
    private:                                                                 \
      static int AddToRegistry()                                             \
      {                                                                      \
        ::testing::UnitTest::GetInstance()                                   \
          ->parameterized_test_registry()                                    \
          .GetTestSuitePatternHolder<test_suite_name>(                       \
            GTEST_STRINGIFY_(test_suite_name),                               \
            ::testing::internal::CodeLocation(__FILE__, __LINE__))           \
          ->AddTestPattern(                                                  \
            GTEST_STRINGIFY_(test_suite_name), GTEST_STRINGIFY_(test_name),  \
            new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
              test_suite_name, test_name)>(),                                \
            ::testing::internal::CodeLocation(__FILE__, __LINE__));          \
        return 0;                                                            \
      }                                                                      \
      GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED static int                       \
        gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_;                    \
    };                                                                       \
    int GTEST_TEST_CLASS_NAME_(test_suite_name,                              \
                               test_name)::gtest_registering_dummy_ =        \
      GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::AddToRegistry();   \
    int GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::InnerTestBody()
856 
857 #endif
858 
859 /* Assigns virtual IO functions to a channel. These functions simply call
860  * the actual system functions.
861  */
862 class VirtualizeIO {
863 public:
864   VirtualizeIO(ares_channel);
865   ~VirtualizeIO();
866 
867   static const ares_socket_functions default_functions;
868 
869 private:
870   ares_channel_t *channel_;
871 };
872 
/*
 * Slightly white-box macro to generate two runs for a given test case:
 * one with no modifications, and one with all IO functions set to use
 * the virtual IO structure (a VirtualizeIO on channel_ alive for the whole
 * InnerTestBody() call).
 * Since no magic socket setup or anything is done in the latter case,
 * this should probably only be used for tests with very vanilla IO
 * requirements.
 *
 * Usage: VIRT_NONVIRT_TEST_F(Fixture, Name) { ...body... }
 * registers Fixture.Name and Fixture.Name_virtualized; both run the body,
 * which becomes VCLASS_NAME(Fixture, Name)::InnerTestBody().  The fixture
 * must provide a `channel_` member.
 */
#define VCLASS_NAME(casename, testname) Virt##casename##_##testname
#define VIRT_NONVIRT_TEST_F(casename, testname)                    \
  class VCLASS_NAME(casename, testname) : public casename {        \
  public:                                                          \
    VCLASS_NAME(casename, testname)()                              \
    {                                                              \
    }                                                              \
    void InnerTestBody();                                          \
  };                                                               \
  GTEST_TEST_(casename, testname, VCLASS_NAME(casename, testname), \
              ::testing::internal::GetTypeId<casename>())          \
  {                                                                \
    InnerTestBody();                                               \
  }                                                                \
  GTEST_TEST_(casename, testname##_virtualized,                    \
              VCLASS_NAME(casename, testname),                     \
              ::testing::internal::GetTypeId<casename>())          \
  {                                                                \
    VirtualizeIO vio(channel_);                                    \
    InnerTestBody();                                               \
  }                                                                \
  void VCLASS_NAME(casename, testname)::InnerTestBody()
903 
904 }  // namespace test
905 }  // namespace ares
906 
907 #endif
908