/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"

#include <cstring>
#include <limits>
#include <memory>
#include <vector>

#include "grpc/support/alloc.h"
#include "grpcpp/grpcpp.h"
#include "grpcpp/security/credentials.h"
#include "grpcpp/server_builder.h"

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

namespace {

// Define an option subclass in order to disable SO_REUSEPORT for the
// server socket.
class NoReusePortOption : public ::grpc::ServerBuilderOption {
 public:
  void UpdateArguments(::grpc::ChannelArguments* args) override {
    args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }

  void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
                         plugins) override {}
};

// Static factory for the default RPC-based rendezvous manager.
RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
  return new RpcRendezvousMgr(env);
}

}  // namespace

GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
    : server_def_(server_def), env_(env), state_(NEW) {}

GrpcServer::~GrpcServer() {
  TF_CHECK_OK(Stop());
  TF_CHECK_OK(Join());

  delete master_service_;
  delete worker_service_;
  delete eager_service_;

  // TODO(mrry): Refactor the *Env classes so that it is less fiddly
  // to destroy them.

  // Shut down all outstanding rendezvous.
  delete worker_env_.rendezvous_mgr;

  // We must delete graph_mgr before device_mgr, due to shared
  // ownership of OpKernels in the executors. (The graph_mgr will
  // free all stateless OpKernels, and pass over borrowed stateful
  // OpKernels, which are also held in their respective devices'
  // OpSegments.)
  if (worker_env_.session_mgr != nullptr) {
    delete worker_env_.session_mgr;  // Deletes graph_mgr's.
  } else {
    // Note: session_mgr's legacy_session_ deletes device_mgr now.
    delete worker_env_.device_mgr;
  }

  // Do not delete (as these are not owned by the server):
  // - master_env_.env
  // - worker_env_.env
  // - worker_env_.compute_pool
}

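// No-op in the base class; subclasses may override this to add further
// arguments or options to the gRPC ServerBuilder before it is built.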
void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}

Status GrpcServer::Init(const GrpcServerOptions& opts) {
  mutex_lock l(mu_);
  CHECK_EQ(state_, NEW);
  master_env_.env = env_;
  worker_env_.env = env_;

  // Check parameters before calling DeviceFactory::AddDevices; otherwise,
  // if `task_index == -1`, the program will abort.

  // Look up the port that has been requested for this task in `server_def_`.
  int requested_port = -1;
  for (const auto& job : server_def_.cluster().job()) {
    if (job.name() == server_def_.job_name()) {
      auto iter = job.tasks().find(server_def_.task_index());
      if (iter == job.tasks().end()) {
        return errors::InvalidArgument("Task ", server_def_.task_index(),
                                       " was not defined in job \"",
                                       server_def_.job_name(), "\"");
      }
      auto colon_index = iter->second.find_last_of(':');
      if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
                                 &requested_port)) {
        return errors::InvalidArgument(
            "Could not parse port for local server from \"", iter->second,
            "\".");
      }
      break;
    }
  }
  if (requested_port == -1) {
    return errors::Internal("Job \"", server_def_.job_name(),
                            "\" was not defined in cluster");
  }

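  // Build the SessionOptions used for local device creation from the default
  // session config carried in `server_def_`; `config` is also merged into
  // each master session's options below.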
  SessionOptions sess_opts;
  ConfigProto config = server_def_.default_session_config();
  sess_opts.config = config;

  // Configure shared devices between master and worker.
  string name_prefix =
      strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
                      "/task:", server_def_.task_index());
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(
      DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
  worker_env_.device_mgr = new DeviceMgr(std::move(devices));
  master_env_.local_devices = worker_env_.device_mgr->ListDevices();
  worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
  worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
                                   ? new RpcRendezvousMgr(&worker_env_)
                                   : opts.rendezvous_mgr_func(&worker_env_);
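  // Derive the default worker name ("/job:<name>/replica:0/task:<index>")
  // from the first local device's fully qualified name.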
  string unused;
  string default_worker_name;
  if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
                                        &default_worker_name, &unused)) {
    return errors::Internal("Could not parse worker name.");
  }

  // N.B. The order of initialization here is intricate, because we
  // wish to allow `requested_port == 0` (for choosing any port,
  // mostly for testing). Therefore, the construction of the channel
  // and worker caches depends on `bound_port_`, which is not set
  // until we call `builder.BuildAndStart()`. We must create the
  // service objects before calling `builder.BuildAndStart()`, but
  // `master_env_` and `worker_env_` are only partially
  // configured. However, this is not dangerous, because we do not
  // start serving requests until `this->Start()` is called, which
  // happens after this method returns.
  //
  // TODO(mrry): Provide a general mechanism for dynamically setting
  // the identities of tasks in the worker pool after the service is
  // running.
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
                           GetServerCredentials(server_def_), &bound_port_);
  builder.SetMaxMessageSize(std::numeric_limits<int32>::max());

  builder.SetOption(
      std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
  // Allow subclasses to specify more args to pass to the gRPC server.
  MaybeMutateBuilder(&builder);
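  // Create the master, worker, and eager services and register them with the
  // gRPC builder. These raw service pointers are owned by this object and are
  // deleted in the destructor.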
  master_impl_ = CreateMaster(&master_env_);
  master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
  worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                  : NewGrpcWorker(&worker_env_, config);
  worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
                                         opts.worker_service_options)
                        .release();
  eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);

  // Register any extra service supplied by the caller.
  if (opts.service_func != nullptr) {
    opts.service_func(&worker_env_, &builder);
  }
  server_ = builder.BuildAndStart();

  if (!server_) {
    return errors::Unknown("Could not start gRPC server");
  }

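  // Construct the worker cache, which maps task names in the cluster to
  // worker channels, with an in-process shortcut for this task's own worker.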
  WorkerCacheInterface* worker_cache;
  WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
  TF_RETURN_IF_ERROR(
      WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
  CHECK_NE(nullptr, worker_cache);

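  // Set up the collective executor manager: use the caller-supplied factory
  // if one was provided, otherwise fall back to the default RPC-based
  // implementation.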
  if (opts.collective_mgr_func) {
    worker_env_.collective_executor_mgr =
        opts.collective_mgr_func(config, &worker_env_, worker_cache);
    if (!worker_env_.collective_executor_mgr) {
      return errors::Internal(
          "collective_mgr_func did not return CollectiveExecutorMgr");
    }
  } else {
    std::unique_ptr<DeviceResolverDistributed> dev_resolver(
        new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
                                      default_worker_name));
    std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
        new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
                                               dev_resolver.get(), worker_cache,
                                               default_worker_name));
    worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
        config, worker_env_.device_mgr, std::move(dev_resolver),
        std::move(param_resolver), worker_cache, default_worker_name);
  }

  // Set up worker environment.
  worker_env_.session_mgr = new SessionMgr(
      &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
      std::unique_ptr<WorkerCacheInterface>(worker_cache),
      [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
        WorkerCacheFactoryOptions options(server_def);
        return WorkerCacheFactory(options, worker_cache);
      });
  worker_env_.compute_pool = ComputePool(sess_opts);

  // Finish setting up master environment.
  master_env_.ops = OpRegistry::Global();
  master_env_.worker_cache = worker_cache;
  master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
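  // Each new master session merges the server's default session config into
  // its own options before the MasterSession is constructed.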
  StatsPublisherFactory stats_factory = opts.stats_factory;
  master_env_.master_session_factory =
      [config, stats_factory](
          SessionOptions options, const MasterEnv* env,
          std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
          std::unique_ptr<WorkerCacheInterface> worker_cache,
          std::unique_ptr<DeviceSet> device_set,
          std::vector<string> filtered_worker_list) {
        options.config.MergeFrom(config);
        return new MasterSession(options, env, std::move(remote_devs),
                                 std::move(worker_cache), std::move(device_set),
                                 std::move(filtered_worker_list),
                                 stats_factory);
      };
  master_env_.worker_cache_factory =
      [this](const WorkerCacheFactoryOptions& options,
             WorkerCacheInterface** worker_cache) {
        return WorkerCacheFactory(options, worker_cache);
      };

  // Provide direct access to the master from in-process clients.
  LocalMaster::Register(target(), master_impl_.get(),
                        config.operation_timeout_in_ms());

  return Status::OK();
}

Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
                                    GrpcChannelSpec* channel_spec) {
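  // For each job in the cluster, map task indices to "host:port" addresses,
  // substituting "localhost:<bound_port_>" for this server's own task so that
  // the cache points at the port that was actually bound.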
  for (const auto& job : options.cluster_def->job()) {
    std::map<int, string> host_ports;
    for (const auto& task : job.tasks()) {
      string& host_port = host_ports[task.first];
      if (!host_port.empty()) {
        return errors::InvalidArgument(
            "JobDef for job \"", job.name(),
            "\" specified two addresses for task \"", task.first,
            "\": ", host_port, " and ", task.second);
      }
      if (job.name() == *options.job_name &&
          task.first == options.task_index) {
        host_port = strings::StrCat("localhost:", bound_port_);
      } else {
        host_port = task.second;
      }
    }
    TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
  }
  return Status::OK();
}

Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                      WorkerCacheInterface** worker_cache) {
  if (options.job_name == nullptr || options.job_name->empty()) {
    Status s = errors::InvalidArgument(
        "The master (current machine) is not included in the provided "
        "cluster_def. ",
        options.cluster_def->DebugString());
    LOG(WARNING) << s;
    return s;
  }

  GrpcChannelSpec channel_spec;
  TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));

  channel_cache_.reset(
      NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));

  string name_prefix =
      strings::StrCat("/job:", *options.job_name, "/replica:0",
                      "/task:", options.task_index);

  const string host_port = channel_cache_->TranslateTask(name_prefix);
  int requested_port;

  auto colon_index = host_port.find_last_of(':');
  if (!strings::safe_strto32(host_port.substr(colon_index + 1),
                             &requested_port)) {
    return errors::Internal("Could not parse port for local server from \"",
                            host_port, "\".");
  }

  if (requested_port != bound_port_) {
    return errors::InvalidArgument("Requested port ", requested_port,
                                   " differs from expected port ", bound_port_);
  }

  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
      channel_cache_, worker_impl_.get(), name_prefix);
  return Status::OK();
}

Status GrpcServer::Start() {
  mutex_lock l(mu_);
  switch (state_) {
    case NEW: {
      master_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_master_service",
                            [this] { master_service_->HandleRPCsLoop(); }));
      worker_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_worker_service",
                            [this] { worker_service_->HandleRPCsLoop(); }));
      eager_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_eager_service",
                            [this] { eager_service_->HandleRPCsLoop(); }));
      state_ = STARTED;
      LOG(INFO) << "Started server with target: " << target();
      return Status::OK();
    }
    case STARTED:
      LOG(INFO) << "Server already started (target: " << target() << ")";
      return Status::OK();
    case STOPPED:
      return errors::FailedPrecondition("Server has stopped.");
    default:
      LOG(FATAL);
  }
}

Status GrpcServer::Stop() {
  mutex_lock l(mu_);
  switch (state_) {
    case NEW:
      state_ = STOPPED;
      return Status::OK();
    case STARTED:
      return errors::Unimplemented(
          "Clean shutdown is not currently implemented");
    case STOPPED:
      LOG(INFO) << "Server already stopped (target: " << target() << ")";
      return Status::OK();
    default:
      LOG(FATAL);
  }
}

Status GrpcServer::Join() {
  mutex_lock l(mu_);
  switch (state_) {
    case NEW:
      // Prevent the server from being started subsequently.
      state_ = STOPPED;
      return Status::OK();
    case STARTED:
    case STOPPED:
      master_thread_.reset();
      worker_thread_.reset();
      eager_thread_.reset();
      return Status::OK();
    default:
      LOG(FATAL);
  }
}

const string GrpcServer::target() const {
  return strings::StrCat("grpc://localhost:", bound_port_);
}

std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
    const ServerDef& server_def) const {
  return ::grpc::InsecureServerCredentials();
}

ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
  // We can do this because SparseGrpcChannelCache is robust to nullptr being
  // returned by the channel creation function.
  return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
}

std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
  return std::unique_ptr<Master>(new Master(master_env, 0.0));
}

/* static */
Status GrpcServer::Create(const ServerDef& server_def, Env* env,
                          std::unique_ptr<ServerInterface>* out_server) {
  std::unique_ptr<GrpcServer> ret(
      new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
  ServiceInitFunction service_func = nullptr;
  GrpcServerOptions options;
  options.rendezvous_mgr_func = NewRpcRendezvousMgr;
  Status s = ret->Init(options);
  if (!s.ok()) {
    LOG(ERROR) << s;
    return s;
  }
  *out_server = std::move(ret);
  return Status::OK();
}

/* static */
Status GrpcServer::Create(const ServerDef& server_def, Env* env,
                          std::unique_ptr<GrpcServer>* out_server) {
  std::unique_ptr<GrpcServer> ret(
      new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
  GrpcServerOptions options;
  options.rendezvous_mgr_func = NewRpcRendezvousMgr;
  Status s = ret->Init(options);
  if (!s.ok()) {
    LOG(ERROR) << s;
    return s;
  }
  *out_server = std::move(ret);
  return Status::OK();
}
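
// A minimal usage sketch (not part of this file's API surface): client code
// typically constructs a server through the generic factory entry point
// declared in server_lib.h rather than calling GrpcServer::Create() directly.
// Here `server_def` is assumed to describe a cluster whose entry for this
// task uses protocol "grpc":
//
//   std::unique_ptr<tensorflow::ServerInterface> server;
//   TF_CHECK_OK(tensorflow::NewServer(server_def, &server));
//   TF_CHECK_OK(server->Start());
//   TF_CHECK_OK(server->Join());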

namespace {

class GrpcServerFactory : public ServerFactory {
 public:
  bool AcceptsOptions(const ServerDef& server_def) override {
    return server_def.protocol() == "grpc";
  }

  Status NewServer(const ServerDef& server_def,
                   std::unique_ptr<ServerInterface>* out_server) override {
    return GrpcServer::Create(server_def, Env::Default(), out_server);
  }
};

// Registers a `ServerFactory` for `GrpcServer` instances.
class GrpcServerRegistrar {
 public:
  GrpcServerRegistrar() {
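    // Route gRPC's heap allocations through TensorFlow's port::Malloc,
    // port::Realloc, and port::Free so both runtimes share one allocator.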
    gpr_allocation_functions alloc_fns;
    memset(&alloc_fns, 0, sizeof(alloc_fns));
    alloc_fns.malloc_fn = port::Malloc;
    alloc_fns.realloc_fn = port::Realloc;
    alloc_fns.free_fn = port::Free;
    gpr_set_allocation_functions(alloc_fns);
    ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
  }
};
static GrpcServerRegistrar registrar;

}  // namespace
}  // namespace tensorflow