1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21
22 #include "grpc++/grpc++.h"
23 #include "grpc++/security/credentials.h"
24 #include "grpc++/server_builder.h"
25 #include "grpc/support/alloc.h"
26
27 #include "tensorflow/core/common_runtime/device_factory.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/process_util.h"
30 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
31 #include "tensorflow/core/distributed_runtime/local_master.h"
32 #include "tensorflow/core/distributed_runtime/master.h"
33 #include "tensorflow/core/distributed_runtime/master_env.h"
34 #include "tensorflow/core/distributed_runtime/master_session.h"
35 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
36 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
40 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
41 #include "tensorflow/core/distributed_runtime/server_lib.h"
42 #include "tensorflow/core/distributed_runtime/worker_env.h"
43 #include "tensorflow/core/framework/op.h"
44 #include "tensorflow/core/lib/strings/strcat.h"
45 #include "tensorflow/core/platform/env.h"
46 #include "tensorflow/core/platform/mem.h"
47 #include "tensorflow/core/public/session_options.h"
48
49 namespace tensorflow {
50
51 namespace {
52
53 // Define an option subclass in order to disable SO_REUSEPORT for the
54 // server socket.
55 class NoReusePortOption : public ::grpc::ServerBuilderOption {
56 public:
UpdateArguments(::grpc::ChannelArguments * args)57 void UpdateArguments(::grpc::ChannelArguments* args) override {
58 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
59 }
60
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)61 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
62 plugins) override {}
63 };
64
65 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)66 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
67 return new RpcRendezvousMgr(env);
68 }
69
70 } // namespace
71
GrpcServer(const ServerDef & server_def,Env * env)72 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
73 : server_def_(server_def), env_(env), state_(NEW) {}
74
~GrpcServer()75 GrpcServer::~GrpcServer() {
76 TF_CHECK_OK(Stop());
77 TF_CHECK_OK(Join());
78
79 delete master_service_;
80 delete worker_service_;
81
82 // TODO(mrry): Refactor the *Env classes so that it is less fiddly
83 // to destroy them.
84
85 // Shut down all outstanding rendezvous.
86 delete worker_env_.rendezvous_mgr;
87
88 // We must delete graph_mgr before device_mgr, due to shared
89 // ownership of OpKernels in the executors. (The graph_mgr will
90 // free all stateless OpKernels, and pass over borrowed stateful
91 // OpKernels, which are also held in their respective devices'
92 // OpSegments.)
93 if (worker_env_.session_mgr != nullptr) {
94 delete worker_env_.session_mgr; // Deletes graph_mgr's.
95 } else {
96 // Note: session_mgr's legacy_session_ deletes device_mgr now.
97 delete worker_env_.device_mgr;
98 }
99
100 // Do not delete (as these are not owned by the server):
101 // - master_env_.env
102 // - worker_env_.env
103 // - worker_env_.compute_pool
104 }
105
// Sets up the master and worker environments, registers the gRPC services
// and builds/starts the underlying gRPC server. Must be called while the
// server is in the NEW state (enforced with CHECK_EQ); holds `mu_` for the
// whole call.
//
// Args:
//   service_func: optional hook; when non-null it is invoked with
//     `&worker_env_` and the ServerBuilder before BuildAndStart().
//   rendezvous_mgr_func: creates `worker_env_.rendezvous_mgr`; when it
//     compares equal to nullptr an RpcRendezvousMgr is created instead.
//   worker_func: creates `worker_impl_`; when empty NewGrpcWorker() is used.
//
// Returns:
//   OK on success. Otherwise the error from device creation or from
//   WorkerCacheFactory(), InvalidArgument when the task or its port cannot
//   be found/parsed in the cluster definition, Internal when the worker name
//   cannot be parsed or the job is missing from the cluster, or Unknown when
//   the gRPC server could not be started.
Status GrpcServer::Init(
    ServiceInitFunction service_func,
    const RendezvousMgrCreationFunction& rendezvous_mgr_func,
    const WorkerCreationFunction& worker_func) {
  mutex_lock l(mu_);
  CHECK_EQ(state_, NEW);
  master_env_.env = env_;
  worker_env_.env = env_;

  // Session options are seeded from the server's default session config.
  SessionOptions sess_opts;
  ConfigProto config = server_def_.default_session_config();
  sess_opts.config = config;

  // Configure shared devices between master and worker.
  string name_prefix =
      strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
                      "/task:", server_def_.task_index());
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
                                               &master_env_.local_devices));
  worker_env_.local_devices = master_env_.local_devices;
  worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
  worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
                                   ? new RpcRendezvousMgr(&worker_env_)
                                   : rendezvous_mgr_func(&worker_env_);
  // Validate that the first local device's name can be split; the resulting
  // pieces are not used further in this function.
  // NOTE(review): indexes local_devices[0] without an emptiness check --
  // presumably AddDevices always yields at least one device; confirm.
  string unused;
  string default_worker_name;
  if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
                                        &default_worker_name, &unused)) {
    return errors::Internal("Could not parse worker name.");
  }

  // Look up the port that has been requested for this task in `server_def_`.
  // -1 doubles as the "job not found" sentinel checked after the loop.
  int requested_port = -1;
  for (const auto& job : server_def_.cluster().job()) {
    if (job.name() == server_def_.job_name()) {
      auto iter = job.tasks().find(server_def_.task_index());
      if (iter == job.tasks().end()) {
        return errors::InvalidArgument("Task ", server_def_.task_index(),
                                       " was not defined in job \"",
                                       server_def_.job_name(), "\"");
      }
      // The address must split into exactly two ':'-separated pieces, so an
      // address containing additional colons is rejected here.
      const std::vector<string> hostname_port =
          str_util::Split(iter->second, ':');
      if (hostname_port.size() != 2 ||
          !strings::safe_strto32(hostname_port[1], &requested_port)) {
        return errors::InvalidArgument(
            "Could not parse port for local server from \"", iter->second,
            "\"");
      } else {
        break;
      }
    }
  }
  if (requested_port == -1) {
    return errors::Internal("Job \"", server_def_.job_name(),
                            "\" was not defined in cluster");
  }

  // N.B. The order of initialization here is intricate, because we
  // wish to allow `requested_port == 0` (for choosing any port,
  // mostly for testing). Therefore, the construction of the channel
  // and worker caches depends on `bound_port_`, which is not set
  // until we call `builder.BuildAndStart()`. We must create the
  // service objects before calling `builder.BuildAndStart()`, but
  // `master_env_` and `worker_env_` are only partially
  // configured. However, this is not dangerous, because we do not
  // start serving requests until `this->Start()` is called, which
  // happens after this method returns.
  //
  // TODO(mrry): Provide a general mechanism for dynamically setting
  // the identities of tasks in the worker pool after the service is
  // running.
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
                           GetServerCredentials(server_def_), &bound_port_);
  builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
  builder.SetOption(
      std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
  master_impl_ = CreateMaster(&master_env_);
  master_service_ = NewGrpcMasterService(
      master_impl_.get(), config.operation_timeout_in_ms(), &builder);
  worker_impl_ =
      worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
  // release(): the raw pointer is kept in worker_service_ and deleted by the
  // destructor.
  worker_service_ =
      NewGrpcWorkerService(worker_impl_.get(), &builder).release();
  // extra service:
  if (service_func != nullptr) {
    service_func(&worker_env_, &builder);
  }
  server_ = builder.BuildAndStart();

  if (!server_) {
    return errors::Unknown("Could not start gRPC server");
  }

  // Build the worker cache for this server's own ServerDef now that
  // `bound_port_` has been filled in.
  WorkerCacheInterface* worker_cache;
  WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
  TF_RETURN_IF_ERROR(
      WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
  CHECK_NE(nullptr, worker_cache);

  // Set up worker environment.
  // The worker cache is handed to SessionMgr wrapped in a unique_ptr; the
  // lambda lets SessionMgr build further caches from other ServerDefs.
  worker_env_.session_mgr = new SessionMgr(
      &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
      std::unique_ptr<WorkerCacheInterface>(worker_cache),
      [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
        WorkerCacheFactoryOptions options(server_def);
        return WorkerCacheFactory(options, worker_cache);
      });
  worker_env_.compute_pool = ComputePool(sess_opts);

  // Finish setting up master environment.
  master_env_.ops = OpRegistry::Global();
  // Same raw pointer that was wrapped in a unique_ptr for SessionMgr above.
  master_env_.worker_cache = worker_cache;
  // Factory for master sessions: merges the server's default config (captured
  // by value) into the per-session options before constructing the session.
  master_env_.master_session_factory =
      [config](
          SessionOptions options, const MasterEnv* env,
          std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
          std::unique_ptr<WorkerCacheInterface> worker_cache,
          std::unique_ptr<DeviceSet> device_set) {
        options.config.MergeFrom(config);
        return new MasterSession(options, env, std::move(remote_devs),
                                 std::move(worker_cache), std::move(device_set),
                                 CreateNoOpStatsPublisher);
      };
  master_env_.worker_cache_factory =
      [this](const WorkerCacheFactoryOptions& options,
             WorkerCacheInterface** worker_cache) {
        return WorkerCacheFactory(options, worker_cache);
      };

  // Provide direct access to the master from in-process clients.
  LocalMaster::Register(target(), master_impl_.get(),
                        config.operation_timeout_in_ms());

  return Status::OK();
}
243
Init(ServiceInitFunction service_func,const RendezvousMgrCreationFunction & rendezvous_mgr_func)244 Status GrpcServer::Init(
245 ServiceInitFunction service_func,
246 const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
247 return Init(service_func, rendezvous_mgr_func, nullptr);
248 }
249
Init()250 Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }
251
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)252 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
253 GrpcChannelSpec* channel_spec) {
254 for (const auto& job : options.cluster_def->job()) {
255 std::map<int, string> host_ports;
256 for (const auto& task : job.tasks()) {
257 string& host_port = host_ports[task.first];
258 if (!host_port.empty()) {
259 return errors::InvalidArgument("JobDef for job \"", job.name(),
260 "\" specified two addresses for task \"",
261 task.first, "\": ", host_port, " and ",
262 task.second);
263 }
264 if (job.name() == *options.job_name && task.first == options.task_index) {
265 host_port = strings::StrCat("localhost:", bound_port_);
266 } else {
267 host_port = task.second;
268 }
269 }
270 TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
271 }
272 return Status::OK();
273 }
274
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)275 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
276 WorkerCacheInterface** worker_cache) {
277 if (options.job_name == nullptr || options.job_name->empty()) {
278 Status s = errors::InvalidArgument(
279 "The master (current machine) is not included in the provided "
280 "cluster_def. ",
281 options.cluster_def->DebugString());
282 LOG(WARNING) << s;
283 return s;
284 }
285
286 GrpcChannelSpec channel_spec;
287 TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
288
289 std::unique_ptr<GrpcChannelCache> channel_cache(
290 NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
291
292 string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
293 "/task:", options.task_index);
294
295 const string host_port = channel_cache->TranslateTask(name_prefix);
296 int requested_port;
297
298 if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
299 &requested_port)) {
300 return errors::Internal("Could not parse port for local server from \"",
301 channel_cache->TranslateTask(name_prefix), "\".");
302 }
303 if (requested_port != bound_port_) {
304 return errors::InvalidArgument("Requested port ", requested_port,
305 " differs from expected port ", bound_port_);
306 }
307
308 *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
309 channel_cache.release(), worker_impl_.get(), name_prefix);
310 return Status::OK();
311 }
312
Start()313 Status GrpcServer::Start() {
314 mutex_lock l(mu_);
315 switch (state_) {
316 case NEW: {
317 master_thread_.reset(
318 env_->StartThread(ThreadOptions(), "TF_master_service",
319 [this] { master_service_->HandleRPCsLoop(); }));
320 worker_thread_.reset(
321 env_->StartThread(ThreadOptions(), "TF_worker_service",
322 [this] { worker_service_->HandleRPCsLoop(); }));
323 state_ = STARTED;
324 LOG(INFO) << "Started server with target: " << target();
325 return Status::OK();
326 }
327 case STARTED:
328 LOG(INFO) << "Server already started (target: " << target() << ")";
329 return Status::OK();
330 case STOPPED:
331 return errors::FailedPrecondition("Server has stopped.");
332 default:
333 LOG(FATAL);
334 }
335 }
336
Stop()337 Status GrpcServer::Stop() {
338 mutex_lock l(mu_);
339 switch (state_) {
340 case NEW:
341 state_ = STOPPED;
342 return Status::OK();
343 case STARTED:
344 return errors::Unimplemented(
345 "Clean shutdown is not currently implemented");
346 case STOPPED:
347 LOG(INFO) << "Server already stopped (target: " << target() << ")";
348 return Status::OK();
349 default:
350 LOG(FATAL);
351 }
352 }
353
Join()354 Status GrpcServer::Join() {
355 mutex_lock l(mu_);
356 switch (state_) {
357 case NEW:
358 // Prevent the server from being started subsequently.
359 state_ = STOPPED;
360 return Status::OK();
361 case STARTED:
362 case STOPPED:
363 master_thread_.reset();
364 worker_thread_.reset();
365 return Status::OK();
366 default:
367 LOG(FATAL);
368 }
369 }
370
target() const371 const string GrpcServer::target() const {
372 return strings::StrCat("grpc://localhost:", bound_port_);
373 }
374
GetServerCredentials(const ServerDef & server_def) const375 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
376 const ServerDef& server_def) const {
377 return ::grpc::InsecureServerCredentials();
378 }
379
GetChannelCreationFunction() const380 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
381 // We can do this because SparseGrpcChannelCache is robust to nullptr being
382 // returned by the channel creation function
383 return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
384 }
385
CreateMaster(MasterEnv * master_env)386 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
387 return std::unique_ptr<Master>(new Master(master_env, 0.0));
388 }
389
390 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)391 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
392 std::unique_ptr<ServerInterface>* out_server) {
393 std::unique_ptr<GrpcServer> ret(
394 new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
395 ServiceInitFunction service_func = nullptr;
396 TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
397 *out_server = std::move(ret);
398 return Status::OK();
399 }
400
401 namespace {
402
403 class GrpcServerFactory : public ServerFactory {
404 public:
AcceptsOptions(const ServerDef & server_def)405 bool AcceptsOptions(const ServerDef& server_def) override {
406 return server_def.protocol() == "grpc";
407 }
408
NewServer(const ServerDef & server_def,std::unique_ptr<ServerInterface> * out_server)409 Status NewServer(const ServerDef& server_def,
410 std::unique_ptr<ServerInterface>* out_server) override {
411 return GrpcServer::Create(server_def, Env::Default(), out_server);
412 }
413 };
414
415 // Registers a `ServerFactory` for `GrpcServer` instances.
416 class GrpcServerRegistrar {
417 public:
GrpcServerRegistrar()418 GrpcServerRegistrar() {
419 gpr_allocation_functions alloc_fns;
420 memset(&alloc_fns, 0, sizeof(alloc_fns));
421 alloc_fns.malloc_fn = port::Malloc;
422 alloc_fns.realloc_fn = port::Realloc;
423 alloc_fns.free_fn = port::Free;
424 gpr_set_allocation_functions(alloc_fns);
425 ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
426 }
427 };
428 static GrpcServerRegistrar registrar;
429
430 } // namespace
431 } // namespace tensorflow
432