• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
17 #ifndef _WIN32
18 #include <arpa/inet.h>
19 #include <netdb.h>
20 #include <netinet/in.h>
21 #include <sys/socket.h>
22 #else
23 #include <Windows.h>
24 #include <winsock2.h>
25 #include <ws2tcpip.h>
26 #endif
27 #include <sys/types.h>
28 
29 namespace tensorflow {
30 
namespace {

// Hostnames whose IPv4 addresses are resolved when the cache starts and then
// periodically refreshed by the background worker. The vector is allocated
// with `new` and intentionally never freed (NOTE(review): presumably to avoid
// static destruction-order problems — confirm).
const std::vector<string>& kCachedDomainNames =
    *new std::vector<string>{"www.googleapis.com", "storage.googleapis.com"};

// Logs a failed getaddrinfo() call for `name`.
//
// On POSIX, EAI_SYSTEM means the underlying cause is stored in errno, so that
// is reported through strerror(); every other code goes through
// gai_strerror(). On Windows every code goes through gai_strerror().
inline void print_getaddrinfo_error(const string& name, int error_code) {
#ifdef _WIN32
  // TODO: WSAGetLastError() would give a more precise error than
  // gai_strerror() here.
  LOG(ERROR) << "Error resolving " << name << ": " << gai_strerror(error_code);
#else
  if (error_code != EAI_SYSTEM) {
    LOG(ERROR) << "Error resolving " << name << ": "
               << gai_strerror(error_code);
    return;
  }
  LOG(ERROR) << "Error resolving " << name
             << " (EAI_SYSTEM): " << strerror(errno);
#endif
}

// Returns a reference to one element of `items`, chosen uniformly at random
// with the engine `random`. CHECK-fails when `items` is empty.
template <typename T>
const T& SelectRandomItemUniform(std::default_random_engine* random,
                                 const std::vector<T>& items) {
  CHECK_GT(items.size(), 0);
  const size_t last_index = items.size() - 1u;
  std::uniform_int_distribution<size_t> pick(0u, last_index);
  return items[pick(*random)];
}
}  // namespace
62 
GcsDnsCache(Env * env,int64_t refresh_rate_secs)63 GcsDnsCache::GcsDnsCache(Env* env, int64_t refresh_rate_secs)
64     : env_(env), refresh_rate_secs_(refresh_rate_secs) {}
65 
AnnotateRequest(HttpRequest * request)66 void GcsDnsCache::AnnotateRequest(HttpRequest* request) {
67   // TODO(saeta): Denylist failing IP addresses.
68   mutex_lock l(mu_);
69   if (!started_) {
70     VLOG(1) << "Starting GCS DNS cache.";
71     DCHECK(!worker_) << "Worker thread already exists!";
72     // Perform DNS resolutions to warm the cache.
73     addresses_ = ResolveNames(kCachedDomainNames);
74 
75     // Note: we opt to use a thread instead of a delayed closure.
76     worker_.reset(env_->StartThread({}, "gcs_dns_worker",
77                                     [this]() { return WorkerThread(); }));
78     started_ = true;
79   }
80 
81   CHECK_EQ(kCachedDomainNames.size(), addresses_.size());
82   for (size_t i = 0; i < kCachedDomainNames.size(); ++i) {
83     const string& name = kCachedDomainNames[i];
84     const std::vector<string>& addresses = addresses_[i];
85     if (!addresses.empty()) {
86       const string& chosen_address =
87           SelectRandomItemUniform(&random_, addresses);
88       request->AddResolveOverride(name, 443, chosen_address);
89       VLOG(1) << "Annotated DNS mapping: " << name << " --> " << chosen_address;
90     } else {
91       LOG(WARNING) << "No IP addresses available for " << name;
92     }
93   }
94 }
95 
// Resolves `name` with getaddrinfo(), restricted to IPv4 stream sockets, and
// returns each address in textual form, in the order getaddrinfo() listed
// them.
//
// Non-AF_INET entries are skipped with a warning. Resolution and formatting
// failures are logged. When nothing could be resolved, the returned vector is
// empty.
/* static */ std::vector<string> GcsDnsCache::ResolveName(const string& name) {
  VLOG(1) << "Resolving DNS name: " << name;

  addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;  // Only use IPv4 for now.
  hints.ai_socktype = SOCK_STREAM;

  addrinfo* head = nullptr;
  const int status = getaddrinfo(name.c_str(), nullptr, &hints, &head);

  std::vector<string> resolved;
  if (status != 0) {
    print_getaddrinfo_error(name, status);
  } else {
    for (const addrinfo* entry = head; entry != nullptr;
         entry = entry->ai_next) {
      const bool is_ipv4 = entry->ai_family == AF_INET &&
                           entry->ai_addr->sa_family == AF_INET;
      if (!is_ipv4) {
        LOG(WARNING) << "Non-IPv4 address returned. ai_family: "
                     << entry->ai_family
                     << ". sa_family: " << entry->ai_addr->sa_family << ".";
        continue;
      }

      char text[INET_ADDRSTRLEN];
      void* raw_address =
          &(reinterpret_cast<sockaddr_in*>(entry->ai_addr)->sin_addr);
      if (inet_ntop(entry->ai_addr->sa_family, raw_address, text,
                    INET_ADDRSTRLEN) == nullptr) {
        LOG(ERROR) << "Error converting response to IP address for " << name
                   << ": " << strerror(errno);
        continue;
      }
      resolved.emplace_back(text);
      VLOG(1) << "... address: " << text;
    }
  }

  // getaddrinfo() only fills `head` on success; release it if it was set.
  if (head != nullptr) {
    freeaddrinfo(head);
  }
  return resolved;
}
135 
136 // Performs DNS resolution for a set of DNS names. The return vector contains
137 // one element for each element in 'names', and each element is itself a
138 // vector of IP addresses (in textual form).
139 //
140 // If DNS resolution fails for any name, then that slot in the return vector
141 // will still be present, but will be an empty vector.
142 //
143 // Ensures: names.size() == return_value.size()
144 
ResolveNames(const std::vector<string> & names)145 std::vector<std::vector<string>> GcsDnsCache::ResolveNames(
146     const std::vector<string>& names) {
147   std::vector<std::vector<string>> all_addresses;
148   all_addresses.reserve(names.size());
149   for (const string& name : names) {
150     all_addresses.push_back(ResolveName(name));
151   }
152   return all_addresses;
153 }
154 
WorkerThread()155 void GcsDnsCache::WorkerThread() {
156   while (true) {
157     {
158       // Don't immediately re-resolve the addresses.
159       mutex_lock l(mu_);
160       if (cancelled_) return;
161       cond_var_.wait_for(l, std::chrono::seconds(refresh_rate_secs_));
162       if (cancelled_) return;
163     }
164 
165     // Resolve DNS values
166     auto new_addresses = ResolveNames(kCachedDomainNames);
167 
168     {
169       mutex_lock l(mu_);
170       // Update instance variables.
171       addresses_.swap(new_addresses);
172     }
173   }
174 }
175 
176 }  // namespace tensorflow
177