1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
17 #ifndef _WIN32
18 #include <arpa/inet.h>
19 #include <netdb.h>
20 #include <netinet/in.h>
21 #include <sys/socket.h>
22 #else
23 #include <Windows.h>
24 #include <winsock2.h>
25 #include <ws2tcpip.h>
26 #endif
27 #include <sys/types.h>
28
29 namespace tensorflow {
30
31 namespace {
32
33 const std::vector<string>& kCachedDomainNames =
34 *new std::vector<string>{"www.googleapis.com", "storage.googleapis.com"};
35
// Logs, at ERROR severity, a human-readable description of a failed
// getaddrinfo() call.
//
//   name:       the host name that was being resolved.
//   error_code: the value returned by getaddrinfo().
inline void print_getaddrinfo_error(const string& name, int error_code) {
#ifndef _WIN32
  if (error_code != EAI_SYSTEM) {
    // Every code other than EAI_SYSTEM is described by gai_strerror().
    LOG(ERROR) << "Error resolving " << name << ": "
               << gai_strerror(error_code);
    return;
  }
  // For EAI_SYSTEM the message text is taken from errno instead.
  LOG(ERROR) << "Error resolving " << name
             << " (EAI_SYSTEM): " << strerror(errno);
#else
  // TODO: WSAGetLastError is better than gai_strerror.
  LOG(ERROR) << "Error resolving " << name << ": " << gai_strerror(error_code);
#endif
}
50
// Returns a reference to one element of `items`, chosen by drawing an index
// from a uniform integer distribution over [0, items.size() - 1] with
// `random` as the generator.
//
// CHECK-fails if `items` is empty.
template <typename T>
const T& SelectRandomItemUniform(std::default_random_engine* random,
                                 const std::vector<T>& items) {
  CHECK_GT(items.size(), 0);
  const size_t last_index = items.size() - 1u;
  std::uniform_int_distribution<size_t> pick_index(0u, last_index);
  return items[pick_index(*random)];
}
61 } // namespace
62
GcsDnsCache(Env * env,int64_t refresh_rate_secs)63 GcsDnsCache::GcsDnsCache(Env* env, int64_t refresh_rate_secs)
64 : env_(env), refresh_rate_secs_(refresh_rate_secs) {}
65
AnnotateRequest(HttpRequest * request)66 void GcsDnsCache::AnnotateRequest(HttpRequest* request) {
67 // TODO(saeta): Denylist failing IP addresses.
68 mutex_lock l(mu_);
69 if (!started_) {
70 VLOG(1) << "Starting GCS DNS cache.";
71 DCHECK(!worker_) << "Worker thread already exists!";
72 // Perform DNS resolutions to warm the cache.
73 addresses_ = ResolveNames(kCachedDomainNames);
74
75 // Note: we opt to use a thread instead of a delayed closure.
76 worker_.reset(env_->StartThread({}, "gcs_dns_worker",
77 [this]() { return WorkerThread(); }));
78 started_ = true;
79 }
80
81 CHECK_EQ(kCachedDomainNames.size(), addresses_.size());
82 for (size_t i = 0; i < kCachedDomainNames.size(); ++i) {
83 const string& name = kCachedDomainNames[i];
84 const std::vector<string>& addresses = addresses_[i];
85 if (!addresses.empty()) {
86 const string& chosen_address =
87 SelectRandomItemUniform(&random_, addresses);
88 request->AddResolveOverride(name, 443, chosen_address);
89 VLOG(1) << "Annotated DNS mapping: " << name << " --> " << chosen_address;
90 } else {
91 LOG(WARNING) << "No IP addresses available for " << name;
92 }
93 }
94 }
95
// Resolves `name` with getaddrinfo(), restricted to IPv4 stream sockets, and
// returns every address found in textual form.
//
// Returns an empty vector when getaddrinfo() fails (the failure is logged).
// Entries that are not IPv4, or that inet_ntop() cannot format, are logged
// and left out of the result.
/* static */ std::vector<string> GcsDnsCache::ResolveName(const string& name) {
  VLOG(1) << "Resolving DNS name: " << name;

  addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;  // Only use IPv4 for now.
  hints.ai_socktype = SOCK_STREAM;
  addrinfo* result = nullptr;
  const int return_code = getaddrinfo(name.c_str(), nullptr, &hints, &result);

  std::vector<string> output;
  if (return_code != 0) {
    print_getaddrinfo_error(name, return_code);
  } else {
    // Walk the linked list of results that getaddrinfo() produced.
    for (const addrinfo* entry = result; entry != nullptr;
         entry = entry->ai_next) {
      if (!(entry->ai_family == AF_INET &&
            entry->ai_addr->sa_family == AF_INET)) {
        LOG(WARNING) << "Non-IPv4 address returned. ai_family: "
                     << entry->ai_family
                     << ". sa_family: " << entry->ai_addr->sa_family << ".";
        continue;
      }
      char text[INET_ADDRSTRLEN];
      void* raw_address =
          &(reinterpret_cast<sockaddr_in*>(entry->ai_addr)->sin_addr);
      const char* formatted = inet_ntop(entry->ai_addr->sa_family, raw_address,
                                        text, INET_ADDRSTRLEN);
      if (formatted != nullptr) {
        output.emplace_back(text);
        VLOG(1) << "... address: " << text;
      } else {
        LOG(ERROR) << "Error converting response to IP address for " << name
                   << ": " << strerror(errno);
      }
    }
  }
  // Release the list allocated by getaddrinfo(), if there is one.
  if (result != nullptr) {
    freeaddrinfo(result);
  }
  return output;
}
135
136 // Performs DNS resolution for a set of DNS names. The return vector contains
137 // one element for each element in 'names', and each element is itself a
138 // vector of IP addresses (in textual form).
139 //
140 // If DNS resolution fails for any name, then that slot in the return vector
141 // will still be present, but will be an empty vector.
142 //
143 // Ensures: names.size() == return_value.size()
144
ResolveNames(const std::vector<string> & names)145 std::vector<std::vector<string>> GcsDnsCache::ResolveNames(
146 const std::vector<string>& names) {
147 std::vector<std::vector<string>> all_addresses;
148 all_addresses.reserve(names.size());
149 for (const string& name : names) {
150 all_addresses.push_back(ResolveName(name));
151 }
152 return all_addresses;
153 }
154
WorkerThread()155 void GcsDnsCache::WorkerThread() {
156 while (true) {
157 {
158 // Don't immediately re-resolve the addresses.
159 mutex_lock l(mu_);
160 if (cancelled_) return;
161 cond_var_.wait_for(l, std::chrono::seconds(refresh_rate_secs_));
162 if (cancelled_) return;
163 }
164
165 // Resolve DNS values
166 auto new_addresses = ResolveNames(kCachedDomainNames);
167
168 {
169 mutex_lock l(mu_);
170 // Update instance variables.
171 addresses_.swap(new_addresses);
172 }
173 }
174 }
175
176 } // namespace tensorflow
177