• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // -*- mode: c++ -*-
2 #ifndef ARES_TEST_H
3 #define ARES_TEST_H
4 
5 #include "dns-proto.h"
6 // Include ares internal file for DNS protocol constants
7 #include "nameser.h"
8 
9 #include "ares_setup.h"
10 #include "ares.h"
11 
12 #include "gtest/gtest.h"
13 #include "gmock/gmock.h"
14 
15 #ifdef HAVE_CONFIG_H
16 #include "config.h"
17 #endif
18 #if defined(HAVE_USER_NAMESPACE) && defined(HAVE_UTS_NAMESPACE)
19 #define HAVE_CONTAINER
20 #endif
21 
22 #include <functional>
23 #include <list>
24 #include <map>
25 #include <memory>
26 #include <set>
27 #include <string>
28 #include <utility>
29 #include <vector>
30 
31 namespace ares {
32 
// Raw octet type used for DNS wire-format buffers (e.g. mock-server replies).
using byte = unsigned char;
34 
35 namespace test {
36 
// Global test settings; defined out of line, not in this header.
extern bool verbose;   // presumably toggles extra diagnostic output
extern int mock_port;  // presumably the base port used for mock servers

// Predefined lists of address families (presumably AF_INET and/or AF_INET6)
// for parameterizing tests.
extern const std::vector<int> both_families;
extern const std::vector<int> ipv4_family;
extern const std::vector<int> ipv6_family;

// The same families, each paired with a bool mode flag. Pairs of this shape
// are MockChannelTest's parameter: .first is passed as the address family
// and .second as force_tcp.
extern const std::vector<std::pair<int, bool>> both_families_both_modes;
extern const std::vector<std::pair<int, bool>> ipv4_family_both_modes;
extern const std::vector<std::pair<int, bool>> ipv6_family_both_modes;

// The parameter lists tests should actually use.
extern std::vector<int> families;
extern std::vector<std::pair<int, bool>> families_modes;
50 
51 // Process all pending work on ares-owned file descriptors, plus
52 // optionally the given set-of-FDs + work function.
53 void ProcessWork(ares_channel channel,
54                  std::function<std::set<int>()> get_extrafds,
55                  std::function<void(int)> process_extra);
56 std::set<int> NoExtraFDs();
57 
58 // Test fixture that ensures library initialization, and allows
59 // memory allocations to be failed.
60 class LibraryTest : public ::testing::Test {
61  public:
LibraryTest()62   LibraryTest() {
63     EXPECT_EQ(ARES_SUCCESS,
64               ares_library_init_mem(ARES_LIB_INIT_ALL,
65                                     &LibraryTest::amalloc,
66                                     &LibraryTest::afree,
67                                     &LibraryTest::arealloc));
68   }
~LibraryTest()69   ~LibraryTest() {
70     ares_library_cleanup();
71     ClearFails();
72   }
73   // Set the n-th malloc call (of any size) from the library to fail.
74   // (nth == 1 means the next call)
75   static void SetAllocFail(int nth);
76   // Set the next malloc call for the given size to fail.
77   static void SetAllocSizeFail(size_t size);
78   // Remove any pending alloc failures.
79   static void ClearFails();
80 
81   static void *amalloc(size_t size);
82   static void* arealloc(void *ptr, size_t size);
83   static void afree(void *ptr);
84  private:
85   static bool ShouldAllocFail(size_t size);
86   static unsigned long long fails_;
87   static std::map<size_t, int> size_fails_;
88 };
89 
90 // Test fixture that uses a default channel.
91 class DefaultChannelTest : public LibraryTest {
92  public:
DefaultChannelTest()93   DefaultChannelTest() : channel_(nullptr) {
94     EXPECT_EQ(ARES_SUCCESS, ares_init(&channel_));
95     EXPECT_NE(nullptr, channel_);
96   }
97 
~DefaultChannelTest()98   ~DefaultChannelTest() {
99     ares_destroy(channel_);
100     channel_ = nullptr;
101   }
102 
103   // Process all pending work on ares-owned file descriptors.
104   void Process();
105 
106  protected:
107   ares_channel channel_;
108 };
109 
110 // Test fixture that uses a default channel with the specified lookup mode.
111 class DefaultChannelModeTest
112     : public LibraryTest,
113       public ::testing::WithParamInterface<std::string> {
114  public:
DefaultChannelModeTest()115   DefaultChannelModeTest() : channel_(nullptr) {
116     struct ares_options opts = {0};
117     opts.lookups = strdup(GetParam().c_str());
118     int optmask = ARES_OPT_LOOKUPS;
119     EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
120     EXPECT_NE(nullptr, channel_);
121     free(opts.lookups);
122   }
123 
~DefaultChannelModeTest()124   ~DefaultChannelModeTest() {
125     ares_destroy(channel_);
126     channel_ = nullptr;
127   }
128 
129   // Process all pending work on ares-owned file descriptors.
130   void Process();
131 
132  protected:
133   ares_channel channel_;
134 };
135 
136 // Mock DNS server to allow responses to be scripted by tests.
137 class MockServer {
138  public:
139   MockServer(int family, int port, int tcpport = 0);
140   ~MockServer();
141 
142   // Mock method indicating the processing of a particular <name, RRtype>
143   // request.
144   MOCK_METHOD2(OnRequest, void(const std::string& name, int rrtype));
145 
146   // Set the reply to be sent next; the query ID field will be overwritten
147   // with the value from the request.
SetReplyData(const std::vector<byte> & reply)148   void SetReplyData(const std::vector<byte>& reply) { reply_ = reply; }
SetReply(const DNSPacket * reply)149   void SetReply(const DNSPacket* reply) { SetReplyData(reply->data()); }
SetReplyQID(int qid)150   void SetReplyQID(int qid) { qid_ = qid; }
151 
152   // The set of file descriptors that the server handles.
153   std::set<int> fds() const;
154 
155   // Process activity on a file descriptor.
156   void ProcessFD(int fd);
157 
158   // Ports the server is responding to
udpport()159   int udpport() const { return udpport_; }
tcpport()160   int tcpport() const { return tcpport_; }
161 
162  private:
163   void ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen,
164                       int qid, const std::string& name, int rrtype);
165 
166   int udpport_;
167   int tcpport_;
168   int udpfd_;
169   int tcpfd_;
170   std::set<int> connfds_;
171   std::vector<byte> reply_;
172   int qid_;
173 };
174 
175 // Test fixture that uses a mock DNS server.
176 class MockChannelOptsTest : public LibraryTest {
177  public:
178   MockChannelOptsTest(int count, int family, bool force_tcp, struct ares_options* givenopts, int optmask);
179   ~MockChannelOptsTest();
180 
181   // Process all pending work on ares-owned and mock-server-owned file descriptors.
182   void Process();
183 
184  protected:
185   // NiceMockServer doesn't complain about uninteresting calls.
186   typedef testing::NiceMock<MockServer> NiceMockServer;
187   typedef std::vector< std::unique_ptr<NiceMockServer> > NiceMockServers;
188 
189   std::set<int> fds() const;
190   void ProcessFD(int fd);
191 
192   static NiceMockServers BuildServers(int count, int family, int base_port);
193 
194   NiceMockServers servers_;
195   // Convenience reference to first server.
196   NiceMockServer& server_;
197   ares_channel channel_;
198 };
199 
200 class MockChannelTest
201     : public MockChannelOptsTest,
202       public ::testing::WithParamInterface< std::pair<int, bool> > {
203  public:
MockChannelTest()204   MockChannelTest() : MockChannelOptsTest(1, GetParam().first, GetParam().second, nullptr, 0) {}
205 };
206 
207 class MockUDPChannelTest
208     : public MockChannelOptsTest,
209       public ::testing::WithParamInterface<int> {
210  public:
MockUDPChannelTest()211   MockUDPChannelTest() : MockChannelOptsTest(1, GetParam(), false, nullptr, 0) {}
212 };
213 
214 class MockTCPChannelTest
215     : public MockChannelOptsTest,
216       public ::testing::WithParamInterface<int> {
217  public:
MockTCPChannelTest()218   MockTCPChannelTest() : MockChannelOptsTest(1, GetParam(), true, nullptr, 0) {}
219 };
220 
// gMock actions driving a MockServer (or a channel) from inside expectations.

// Set the next reply from raw bytes.
ACTION_P2(SetReplyData, server, bytes) {
  server->SetReplyData(bytes);
}

// Set the next reply from a DNSPacket.
ACTION_P2(SetReply, server, packet) {
  server->SetReply(packet);
}

// Forward a query ID to MockServer::SetReplyQID().
ACTION_P2(SetReplyQID, server, id) {
  server->SetReplyQID(id);
}

// Call ares_cancel() on the given channel; the server argument is unused.
ACTION_P2(CancelChannel, server, chan) {
  ares_cancel(chan);
}
235 
// Value-type copy of a struct hostent: the name, aliases and addresses are
// held as strings, so the result can outlive the C structure and be printed.
struct HostEnt {
  HostEnt() = default;
  HostEnt(const struct hostent* hostent);
  std::string name_;
  std::vector<std::string> aliases_;
  int addrtype_ = -1;  // AF_INET or AF_INET6; -1 when not populated.
  std::vector<std::string> addrs_;
};
std::ostream& operator<<(std::ostream& os, const HostEnt& result);
246 
247 // Structure that describes the result of an ares_host_callback invocation.
248 struct HostResult {
249   // Whether the callback has been invoked.
250   bool done_;
251   // Explicitly provided result information.
252   int status_;
253   int timeouts_;
254   // Contents of the hostent structure, if provided.
255   HostEnt host_;
256 };
257 std::ostream& operator<<(std::ostream& os, const HostResult& result);
258 
259 // Structure that describes the result of an ares_callback invocation.
260 struct SearchResult {
261   // Whether the callback has been invoked.
262   bool done_;
263   // Explicitly provided result information.
264   int status_;
265   int timeouts_;
266   std::vector<byte> data_;
267 };
268 std::ostream& operator<<(std::ostream& os, const SearchResult& result);
269 
// Structure that describes the result of an ares_nameinfo_callback invocation.
struct NameInfoResult {
  // Start in the "callback not yet invoked" state, so a test whose callback
  // never fires reads well-defined values instead of indeterminate ones.
  // These are the same defaults AddrInfoResult uses.
  NameInfoResult() : done_(false), status_(-1), timeouts_(0) {}
  // Whether the callback has been invoked.
  bool done_;
  // Explicitly provided result information (status_ is -1 until set).
  int status_;
  int timeouts_;
  // Node and service strings passed to the callback.
  std::string node_;
  std::string service_;
};
std::ostream& operator<<(std::ostream& os, const NameInfoResult& result);
281 
282 struct AddrInfoDeleter {
operatorAddrInfoDeleter283   void operator() (ares_addrinfo *ptr) {
284     if (ptr) ares_freeaddrinfo(ptr);
285   }
286 };
287 
288 // C++ wrapper for struct ares_addrinfo.
289 using AddrInfo = std::unique_ptr<ares_addrinfo, AddrInfoDeleter>;
290 
291 std::ostream& operator<<(std::ostream& os, const AddrInfo& result);
292 
293 // Structure that describes the result of an ares_addrinfo_callback invocation.
294 struct AddrInfoResult {
AddrInfoResultAddrInfoResult295   AddrInfoResult() : done_(false), status_(-1), timeouts_(0) {}
296   // Whether the callback has been invoked.
297   bool done_;
298   // Explicitly provided result information.
299   int status_;
300   int timeouts_;
301   // Contents of the ares_addrinfo structure, if provided.
302   AddrInfo ai_;
303 };
304 std::ostream& operator<<(std::ostream& os, const AddrInfoResult& result);
305 
// Stock callback implementations that fill out the matching result
// structure (HostResult, SearchResult, NameInfoResult, AddrInfoResult).
// NOTE(review): the structure is presumably passed through `data`.
void HostCallback(void* data, int status, int timeouts,
                  struct hostent* hostent);
void SearchCallback(void* data, int status, int timeouts,
                    unsigned char* abuf, int alen);
void NameInfoCallback(void* data, int status, int timeouts,
                      char* node, char* service);
void AddrInfoCallback(void* data, int status, int timeouts,
                      struct ares_addrinfo* res);
316 
317 // Retrieve the name servers used by a channel.
318 std::vector<std::string> GetNameServers(ares_channel channel);
319 
320 
// RAII class that creates the directory `dirname` for the lifetime of the
// object, i.e. the directory only exists temporarily.
class TransientDir {
 public:
  explicit TransientDir(const std::string& dirname);
  ~TransientDir();

  // Non-copyable: each copy would try to remove the same directory when it
  // is destroyed.
  TransientDir(const TransientDir&) = delete;
  TransientDir& operator=(const TransientDir&) = delete;

 private:
  std::string dirname_;
};
330 
// Thin C++ wrapper around tempnam(): generates a temporary file name for the
// given directory and prefix and returns it as a std::string.
std::string TempNam(const char* dir, const char* prefix);
333 
// RAII class that creates the file `filename` with the given `contents` for
// the lifetime of the object, i.e. the file only exists temporarily.
class TransientFile {
 public:
  TransientFile(const std::string &filename, const std::string &contents);
  ~TransientFile();

  // Non-copyable: each copy would try to remove the same file when it is
  // destroyed.
  TransientFile(const TransientFile &) = delete;
  TransientFile &operator=(const TransientFile &) = delete;

 protected:
  std::string filename_;
};
343 
344 // RAII class for a temporary file with the given contents.
345 class TempFile : public TransientFile {
346  public:
347   TempFile(const std::string& contents);
filename()348   const char* filename() const { return filename_.c_str(); }
349 };
350 
#ifdef _WIN32
extern "C" {

/* Minimal POSIX-style setenv() for Windows, built on _putenv().
 *  - value == NULL is treated as "", and _putenv("NAME=") removes the
 *    variable; unsetenv() below relies on this.
 *  - Returns 0 on success and -1 on failure, like POSIX setenv(). */
static int setenv(const char *name, const char *value, int overwrite)
{
  char  *buffer;
  size_t buf_size;
  int    rc;

  /* POSIX rejects a NULL or empty name, or one containing '='. */
  if (name == NULL || *name == '\0' || strchr(name, '=') != NULL)
    return -1;

  if (value == NULL)
    value = ""; /* For unset */

  /* POSIX: if the variable exists and overwrite is zero, leave the
   * environment unchanged and report success (not failure). */
  if (!overwrite && getenv(name) != NULL)
    return 0;

  buf_size = strlen(name) + strlen(value) + 1 /* = */ + 1 /* NUL */;
  buffer   = (char *)malloc(buf_size);
  if (buffer == NULL)
    return -1;
  _snprintf(buffer, buf_size, "%s=%s", name, value);
  /* The buffer is released right after the call, which relies on the CRT
   * keeping its own copy of the string passed to _putenv(). */
  rc = _putenv(buffer);
  free(buffer);
  return (rc == 0) ? 0 : -1;
}

/* POSIX-style unsetenv() for Windows: removes `name` from the environment. */
static int unsetenv(const char *name)
{
  return setenv(name, NULL, 1);
}

} /* extern "C" */
#endif
384 
// RAII class for a temporary environment variable value.
//
// The constructor sets `name` to `value`, or unsets it when value is
// nullptr. The destructor restores the previous state: the original value
// if the variable existed, otherwise it is unset again.
class EnvValue {
 public:
  EnvValue(const char *name, const char *value) : name_(name), restore_(false) {
    // Copy the current value first: the pointer returned by getenv() may be
    // invalidated by the setenv()/unsetenv() call below.
    char *original = getenv(name);
    if (original) {
      restore_ = true;
      original_ = original;
    }
    if (value) {
      setenv(name_.c_str(), value, 1);
    } else {
      // Treat a null value as "unset". This matches the Windows setenv()
      // shim and avoids passing NULL to POSIX setenv().
      unsetenv(name_.c_str());
    }
  }
  ~EnvValue() {
    if (restore_) {
      setenv(name_.c_str(), original_.c_str(), 1);
    } else {
      unsetenv(name_.c_str());
    }
  }

  // Non-copyable: every copy would restore/unset the variable again.
  EnvValue(const EnvValue&) = delete;
  EnvValue& operator=(const EnvValue&) = delete;

 private:
  std::string name_;
  bool restore_;
  std::string original_;
};
408 
409 
410 #ifdef HAVE_CONTAINER
// Linux-specific support for running test code inside a container (enabled
// when user and UTS namespaces are available); implemented in
// ares-test-ns.cc.

// A callable run inside the container; its int result is the status.
using VoidToIntFn = std::function<int(void)>;
// Name/content string pairs describing the files of a ContainerFilesystem.
using NameContentList = std::vector<std::pair<std::string, std::string>>;
415 
416 class ContainerFilesystem {
417  public:
418   ContainerFilesystem(NameContentList files, const std::string& mountpt);
419   ~ContainerFilesystem();
root()420   std::string root() const { return rootdir_; };
mountpt()421   std::string mountpt() const { return mountpt_; };
422  private:
423   void EnsureDirExists(const std::string& dir);
424   std::string rootdir_;
425   std::string mountpt_;
426   std::list<std::string> dirs_;
427   std::vector<std::unique_ptr<TransientFile>> files_;
428 };
429 
430 int RunInContainer(ContainerFilesystem* fs, const std::string& hostname,
431                    const std::string& domainname, VoidToIntFn fn);
432 
// Name of the helper class generated by CONTAINED_TEST_F.
#define ICLASS_NAME(fixture, test_name) Contained##fixture##_##test_name

// CONTAINED_TEST_F(Fixture, Name, hostname, domainname, files) { ... }
// declares a gtest test whose body (the braces following the macro) becomes
// the static InnerTestBody() of a subclass of Fixture. The test builds a
// ContainerFilesystem from `files` (mount point ".."), runs the body via
// RunInContainer() with the given host/domain names, and expects it to
// return 0.
#define CONTAINED_TEST_F(fixture, test_name, host, domain, fs_files)           \
  class ICLASS_NAME(fixture, test_name) : public fixture {                     \
   public:                                                                     \
    ICLASS_NAME(fixture, test_name)() {}                                       \
    static int InnerTestBody();                                                \
  };                                                                           \
  TEST_F(ICLASS_NAME(fixture, test_name), _) {                                 \
    ContainerFilesystem chroot(fs_files, "..");                                \
    VoidToIntFn fn(ICLASS_NAME(fixture, test_name)::InnerTestBody);            \
    EXPECT_EQ(0, RunInContainer(&chroot, host, domain, fn));                   \
  }                                                                            \
  int ICLASS_NAME(fixture, test_name)::InnerTestBody()
446 
447 #endif
448 
449 /* Assigns virtual IO functions to a channel. These functions simply call
450  * the actual system functions.
451  */
452 class VirtualizeIO {
453 public:
454   VirtualizeIO(ares_channel);
455   ~VirtualizeIO();
456 
457   static const ares_socket_functions default_functions;
458 private:
459   ares_channel channel_;
460 };
461 
/*
 * Slightly white-box macro (it relies on gtest's internal GTEST_TEST_) that
 * registers two tests sharing one body:
 *   - <fixture>.<test_name>              runs the body unmodified;
 *   - <fixture>.<test_name>_virtualized  runs it with a VirtualizeIO installed
 *     on the fixture's channel_, so IO goes through the virtual IO functions.
 * The virtualized run does no special socket setup, so only use this for
 * tests with very plain IO requirements. The fixture must provide channel_.
 */
#define VCLASS_NAME(fixture, test_name) Virt##fixture##_##test_name
#define VIRT_NONVIRT_TEST_F(fixture, test_name)                                \
  class VCLASS_NAME(fixture, test_name) : public fixture {                     \
  public:                                                                      \
    VCLASS_NAME(fixture, test_name)() {}                                       \
    void InnerTestBody();                                                      \
  };                                                                           \
  GTEST_TEST_(fixture, test_name, VCLASS_NAME(fixture, test_name),             \
              ::testing::internal::GetTypeId<fixture>()) {                     \
    InnerTestBody();                                                           \
  }                                                                            \
  GTEST_TEST_(fixture, test_name##_virtualized,                                \
              VCLASS_NAME(fixture, test_name),                                 \
              ::testing::internal::GetTypeId<fixture>()) {                     \
    VirtualizeIO vio(channel_);                                                \
    InnerTestBody();                                                           \
  }                                                                            \
  void VCLASS_NAME(fixture, test_name)::InnerTestBody()
488 
489 }  // namespace test
490 }  // namespace ares
491 
492 #endif
493