1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "grpcpp/server_builder.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/process_util.h"
29 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
30 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
31 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
32 #include "tensorflow/core/distributed_runtime/local_master.h"
33 #include "tensorflow/core/distributed_runtime/master.h"
34 #include "tensorflow/core/distributed_runtime/master_env.h"
35 #include "tensorflow/core/distributed_runtime/master_session.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
42 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
43 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
44 #include "tensorflow/core/distributed_runtime/server_lib.h"
45 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
46 #include "tensorflow/core/distributed_runtime/worker_env.h"
47 #include "tensorflow/core/framework/op.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/platform/cpu_info.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/mem.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/public/session_options.h"
55 #include "tensorflow/core/util/env_var.h"
56
57 namespace tensorflow {
58
59 namespace {
60
61 // Define an option subclass in order to disable SO_REUSEPORT for the
62 // server socket.
63 class NoReusePortOption : public ::grpc::ServerBuilderOption {
64 public:
UpdateArguments(::grpc::ChannelArguments * args)65 void UpdateArguments(::grpc::ChannelArguments* args) override {
66 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
67 }
68
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)69 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
70 plugins) override {}
71 };
72
73 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)74 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
75 return new RpcRendezvousMgr(env);
76 }
77
CreateGrpcWorkerEnv()78 std::unique_ptr<GrpcWorkerEnv> CreateGrpcWorkerEnv() {
79 int num_cpus = port::NumSchedulableCPUs();
80 int64 num_completion_queues;
81 Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
82 &num_completion_queues);
83 if (!status.ok()) {
84 LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
85 }
86 int64 num_threads;
87 status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
88 &num_threads);
89 if (!status.ok()) {
90 LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
91 }
92 return absl::make_unique<GrpcWorkerEnv>(num_completion_queues, num_threads);
93 }
94
95 } // namespace
96
GrpcServer(const ServerDef & server_def,Env * env)97 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
98 : env_(env), state_(NEW), server_def_(server_def) {}
99
~GrpcServer()100 GrpcServer::~GrpcServer() {
101 TF_CHECK_OK(Stop());
102 TF_CHECK_OK(Join());
103
104 delete master_service_;
105 delete worker_service_;
106 delete eager_service_;
107
108 // TODO(mrry): Refactor the *Env classes so that it is less fiddly
109 // to destroy them.
110
111 // Shut down all outstanding rendezvous.
112 delete worker_env_.rendezvous_mgr;
113
114 // We must delete graph_mgr before device_mgr, due to shared
115 // ownership of OpKernels in the executors. (The graph_mgr will
116 // free all stateless OpKernels, and pass over borrowed stateful
117 // OpKernels, which are also held in their respective devices'
118 // OpSegments.)
119 if (worker_env_.session_mgr != nullptr) {
120 delete worker_env_.session_mgr; // Deletes graph_mgr's.
121 } else {
122 // Note: session_mgr's legacy_session_ deletes device_mgr now.
123 delete worker_env_.device_mgr;
124 }
125
126 // Do not delete (as these are not owned by the server):
127 // - master_env_.env
128 // - worker_env_.env
129 // - worker_env_.compute_pool
130 }
131
MaybeMutateBuilder(::grpc::ServerBuilder * builder)132 void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
133
134 // Look up the port that has been requested for this task in `server_def`.
GetPort(const ServerDef & server_def,int * port) const135 Status GrpcServer::GetPort(const ServerDef& server_def, int* port) const {
136 *port = -1;
137 for (const auto& job : server_def.cluster().job()) {
138 if (job.name() == server_def.job_name()) {
139 auto iter = job.tasks().find(server_def.task_index());
140 if (iter == job.tasks().end()) {
141 return errors::InvalidArgument("Task ", server_def.task_index(),
142 " was not defined in job \"",
143 server_def.job_name(), "\"");
144 }
145
146 if (server_def.port() != 0) {
147 *port = server_def.port();
148 } else {
149 auto colon_index = iter->second.find_last_of(':');
150 if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
151 port)) {
152 return errors::InvalidArgument(
153 "Could not parse port for local server from \"", iter->second,
154 "\".");
155 }
156 }
157 break;
158 }
159 }
160 if (*port == -1) {
161 return errors::Internal("Job \"", server_def.job_name(),
162 "\" was not defined in cluster");
163 }
164
165 return Status::OK();
166 }
167
Init(const GrpcServerOptions & opts)168 Status GrpcServer::Init(const GrpcServerOptions& opts) {
169 mutex_lock l(mu_);
170 CHECK_EQ(state_, NEW);
171 master_env_.env = env_;
172 worker_env_.env = env_;
173
174 // Check parameters before DeviceFactory::AddDevices,
175 // otherwise if 'task_index=-1' the program will abort.
176
177 int requested_port;
178 TF_RETURN_IF_ERROR(GetPort(server_def_, &requested_port));
179
180 SessionOptions sess_opts;
181 ConfigProto config = server_def_.default_session_config();
182 sess_opts.config = config;
183
184 // Configure shared devices between master and worker.
185 string name_prefix =
186 strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
187 "/task:", server_def_.task_index());
188 std::vector<std::unique_ptr<Device>> devices;
189 TF_RETURN_IF_ERROR(
190 DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
191 worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
192 master_env_.local_devices = worker_env_.device_mgr->ListDevices();
193 worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
194 worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
195 ? new RpcRendezvousMgr(&worker_env_)
196 : opts.rendezvous_mgr_func(&worker_env_);
197 string unused;
198 string default_worker_name;
199 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
200 &default_worker_name, &unused)) {
201 return errors::Internal("Could not parse worker name.");
202 }
203
204 // N.B. The order of initialization here is intricate, because we
205 // wish to allow `requested_port == 0` (for choosing any port,
206 // mostly for testing). Therefore, the construction of the channel
207 // and worker caches depends on `bound_port_`, which is not set
208 // until we call `builder.BuildAndStart()`. We must create the
209 // service objects before calling `builder.BuildAndStart()`, but
210 // `master_env_` and `worker_env_` are only partially
211 // configured. However, this is not dangerous, because we do not
212 // start serving requests until `this->Start()` is called, which
213 // happens after this method returns.
214 //
215 // TODO(mrry): Provide a general mechanism for dynamically setting
216 // the identities of tasks in the worker pool after the service is
217 // running.
218 ::grpc::ServerBuilder builder;
219 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
220 GetServerCredentials(server_def_), &bound_port_);
221 builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
222
223 builder.SetOption(
224 std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
225 // Allow subclasses to specify more args to pass to the gRPC server.
226 MaybeMutateBuilder(&builder);
227 master_impl_ = CreateMaster(&master_env_);
228 master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
229 worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
230 : NewGrpcWorker(&worker_env_, config);
231 worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
232 opts.worker_service_options)
233 .release();
234 eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
235
236 // extra service:
237 if (opts.service_func != nullptr) {
238 opts.service_func(&worker_env_, &builder);
239 }
240 server_ = builder.BuildAndStart();
241
242 if (!server_) {
243 return errors::Unknown("Could not start gRPC server");
244 }
245 // Create the execution environment for the GRPC workers cache.
246 grpc_worker_env_ = CreateGrpcWorkerEnv();
247
248 WorkerCacheInterface* worker_cache;
249 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
250 TF_RETURN_IF_ERROR(
251 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
252 CHECK_NE(nullptr, worker_cache);
253
254 if (opts.collective_mgr_func) {
255 worker_env_.collective_executor_mgr =
256 opts.collective_mgr_func(config, &worker_env_, worker_cache);
257 if (!worker_env_.collective_executor_mgr) {
258 return errors::Internal(
259 "collective_mgr_func did not return CollectiveExecutorMgr");
260 }
261 } else {
262 std::unique_ptr<DeviceResolverDistributed> dev_resolver(
263 new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
264 default_worker_name));
265 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
266 new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
267 dev_resolver.get(), worker_cache,
268 default_worker_name));
269 worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
270 config, worker_env_.device_mgr, std::move(dev_resolver),
271 std::move(param_resolver), worker_cache, default_worker_name);
272 }
273
274 // Set up worker environment.
275 worker_env_.session_mgr = new SessionMgr(
276 &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
277 std::unique_ptr<WorkerCacheInterface>(worker_cache),
278 [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
279 WorkerCacheFactoryOptions options(server_def);
280 return WorkerCacheFactory(options, worker_cache);
281 });
282 worker_env_.compute_pool = ComputePool(sess_opts);
283
284 // Finish setting up master environment.
285 master_env_.ops = OpRegistry::Global();
286 master_env_.worker_cache = worker_cache;
287 master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
288 StatsPublisherFactory stats_factory = opts.stats_factory;
289 master_env_.master_session_factory =
290 [config, stats_factory](
291 SessionOptions options, const MasterEnv* env,
292 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
293 std::unique_ptr<WorkerCacheInterface> worker_cache,
294 std::unique_ptr<DeviceSet> device_set,
295 std::vector<string> filtered_worker_list) {
296 options.config.MergeFrom(config);
297 return new MasterSession(options, env, std::move(remote_devs),
298 std::move(worker_cache), std::move(device_set),
299 std::move(filtered_worker_list),
300 stats_factory);
301 };
302 master_env_.worker_cache_factory =
303 [this](const WorkerCacheFactoryOptions& options,
304 WorkerCacheInterface** worker_cache) {
305 return WorkerCacheFactory(options, worker_cache);
306 };
307
308 // Provide direct access to the master from in-process clients.
309 LocalMaster::Register(target(), master_impl_.get(),
310 config.operation_timeout_in_ms());
311
312 return Status::OK();
313 }
314
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)315 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
316 GrpcChannelSpec* channel_spec) {
317 for (const auto& job : options.cluster_def->job()) {
318 std::map<int, string> host_ports;
319 for (const auto& task : job.tasks()) {
320 string& host_port = host_ports[task.first];
321 if (!host_port.empty()) {
322 return errors::InvalidArgument("JobDef for job \"", job.name(),
323 "\" specified two addresses for task \"",
324 task.first, "\": ", host_port, " and ",
325 task.second);
326 }
327 if (job.name() == *options.job_name && task.first == options.task_index) {
328 host_port = strings::StrCat("localhost:", bound_port_);
329 } else {
330 host_port = task.second;
331 }
332 }
333 TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
334 }
335 return Status::OK();
336 }
337
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)338 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
339 WorkerCacheInterface** worker_cache) {
340 if (options.job_name == nullptr || options.job_name->empty()) {
341 Status s = errors::InvalidArgument(
342 "The master (current machine) is not included in the provided "
343 "cluster_def. ",
344 options.cluster_def->DebugString());
345 LOG(WARNING) << s;
346 return s;
347 }
348
349 GrpcChannelSpec channel_spec;
350 TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
351
352 std::shared_ptr<GrpcChannelCache> channel_cache(
353 NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
354
355 string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
356 "/task:", options.task_index);
357
358 const string host_port = channel_cache->TranslateTask(name_prefix);
359 int requested_port;
360
361 auto colon_index = host_port.find_last_of(':');
362 if (!strings::safe_strto32(host_port.substr(colon_index + 1),
363 &requested_port)) {
364 return errors::Internal("Could not parse port for local server from \"",
365 host_port, "\".");
366 }
367 if (requested_port != bound_port_) {
368 return errors::InvalidArgument("Requested port ", requested_port,
369 " differs from expected port ", bound_port_);
370 }
371 *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
372 channel_cache, worker_impl(), name_prefix, grpc_worker_env_.get());
373 return Status::OK();
374 }
375
Start()376 Status GrpcServer::Start() {
377 mutex_lock l(mu_);
378 switch (state_) {
379 case NEW: {
380 master_thread_.reset(
381 env_->StartThread(ThreadOptions(), "TF_master_service",
382 [this] { master_service_->HandleRPCsLoop(); }));
383 worker_thread_.reset(
384 env_->StartThread(ThreadOptions(), "TF_worker_service",
385 [this] { worker_service_->HandleRPCsLoop(); }));
386 eager_thread_.reset(
387 env_->StartThread(ThreadOptions(), "TF_eager_service",
388 [this] { eager_service_->HandleRPCsLoop(); }));
389 state_ = STARTED;
390 LOG(INFO) << "Started server with target: " << target();
391 return Status::OK();
392 }
393 case STARTED:
394 LOG(INFO) << "Server already started (target: " << target() << ")";
395 return Status::OK();
396 case STOPPED:
397 return errors::FailedPrecondition("Server has stopped.");
398 default:
399 LOG(FATAL);
400 }
401 }
402
AddMasterEagerContextToEagerService(const tensorflow::uint64 context_id,tensorflow::EagerContext * context)403 Status GrpcServer::AddMasterEagerContextToEagerService(
404 const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
405 auto* eager_service =
406 static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
407 return eager_service->CreateMasterContext(context_id, context);
408 }
409
UpdateServerDef(const ServerDef & server_def)410 Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
411 mutex_lock l(mu_);
412 server_def_ = server_def;
413 WorkerCacheInterface* worker_cache;
414 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
415 TF_RETURN_IF_ERROR(
416 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
417 if (worker_cache == nullptr) {
418 return errors::InvalidArgument(
419 "Failed to build worker cache with the provided server def.");
420 }
421
422 string default_worker_name;
423 string unused;
424 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
425 &default_worker_name, &unused)) {
426 return errors::Internal("Could not parse worker name.");
427 }
428 std::unique_ptr<DeviceResolverDistributed> dev_resolver(
429 new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
430 default_worker_name));
431 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
432 new CollectiveParamResolverDistributed(
433 server_def_.default_session_config(), worker_env_.device_mgr,
434 dev_resolver.get(), worker_cache, default_worker_name));
435 worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
436 server_def_.default_session_config(), worker_env_.device_mgr,
437 std::move(dev_resolver), std::move(param_resolver), worker_cache,
438 default_worker_name);
439
440 master_env_.worker_cache = worker_cache;
441 master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
442 return Status::OK();
443 }
444
Stop()445 Status GrpcServer::Stop() {
446 mutex_lock l(mu_);
447 switch (state_) {
448 case NEW:
449 state_ = STOPPED;
450 return Status::OK();
451 case STARTED:
452 return errors::Unimplemented(
453 "Clean shutdown is not currently implemented");
454 case STOPPED:
455 LOG(INFO) << "Server already stopped (target: " << target() << ")";
456 return Status::OK();
457 default:
458 LOG(FATAL);
459 }
460 }
461
Join()462 Status GrpcServer::Join() {
463 mutex_lock l(mu_);
464 switch (state_) {
465 case NEW:
466 // Prevent the server from being started subsequently.
467 state_ = STOPPED;
468 return Status::OK();
469 case STARTED:
470 case STOPPED:
471 master_thread_.reset();
472 worker_thread_.reset();
473 eager_thread_.reset();
474 return Status::OK();
475 default:
476 LOG(FATAL);
477 }
478 }
479
target() const480 const string GrpcServer::target() const {
481 return strings::StrCat("grpc://localhost:", bound_port_);
482 }
483
GetServerCredentials(const ServerDef & server_def) const484 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
485 const ServerDef& server_def) const {
486 return ::grpc::InsecureServerCredentials();
487 }
488
GetChannelCreationFunction() const489 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
490 // We can do this because SparseGrpcChannelCache is robust to nullptr being
491 // returned by the channel creation function
492 return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
493 }
494
CreateMaster(MasterEnv * master_env)495 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
496 return std::unique_ptr<Master>(new Master(master_env, 0.0));
497 }
498
499 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)500 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
501 std::unique_ptr<ServerInterface>* out_server) {
502 std::unique_ptr<GrpcServer> ret(
503 new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
504 ServiceInitFunction service_func = nullptr;
505 GrpcServerOptions options;
506 options.rendezvous_mgr_func = NewRpcRendezvousMgr;
507 Status s = ret->Init(options);
508 if (!s.ok()) {
509 LOG(ERROR) << s;
510 return s;
511 }
512 *out_server = std::move(ret);
513 return Status::OK();
514 }
515
516 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)517 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
518 std::unique_ptr<GrpcServer>* out_server) {
519 std::unique_ptr<GrpcServer> ret(
520 new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
521 GrpcServerOptions options;
522 options.rendezvous_mgr_func = NewRpcRendezvousMgr;
523 Status s = ret->Init(options);
524 if (!s.ok()) {
525 LOG(ERROR) << s;
526 return s;
527 }
528 *out_server = std::move(ret);
529 return Status::OK();
530 }
531
532 namespace {
533
534 class GrpcServerFactory : public ServerFactory {
535 public:
AcceptsOptions(const ServerDef & server_def)536 bool AcceptsOptions(const ServerDef& server_def) override {
537 return server_def.protocol() == "grpc";
538 }
539
NewServer(const ServerDef & server_def,std::unique_ptr<ServerInterface> * out_server)540 Status NewServer(const ServerDef& server_def,
541 std::unique_ptr<ServerInterface>* out_server) override {
542 return GrpcServer::Create(server_def, Env::Default(), out_server);
543 }
544 };
545
546 // Registers a `ServerFactory` for `GrpcServer` instances.
547 class GrpcServerRegistrar {
548 public:
GrpcServerRegistrar()549 GrpcServerRegistrar() {
550 ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
551 }
552 };
553 static GrpcServerRegistrar registrar;
554
555 } // namespace
556 } // namespace tensorflow
557