1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "grpcpp/server_builder.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/process_util.h"
29 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
30 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
31 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
32 #include "tensorflow/core/distributed_runtime/local_master.h"
33 #include "tensorflow/core/distributed_runtime/master.h"
34 #include "tensorflow/core/distributed_runtime/master_env.h"
35 #include "tensorflow/core/distributed_runtime/master_session.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
42 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
43 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
44 #include "tensorflow/core/distributed_runtime/server_lib.h"
45 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
46 #include "tensorflow/core/distributed_runtime/worker_env.h"
47 #include "tensorflow/core/framework/op.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/nccl/collective_communicator.h"
51 #include "tensorflow/core/platform/cpu_info.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mem.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
56 #include "tensorflow/core/public/session_options.h"
57 #include "tensorflow/core/util/env_var.h"
58
59 namespace tensorflow {
60
61 namespace {
62
63 // Define an option subclass in order to disable SO_REUSEPORT for the
64 // server socket.
65 class NoReusePortOption : public ::grpc::ServerBuilderOption {
66 public:
UpdateArguments(::grpc::ChannelArguments * args)67 void UpdateArguments(::grpc::ChannelArguments* args) override {
68 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
69 }
70
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)71 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
72 plugins) override {}
73 };
74
75 // Define an option subclass in order to enable SO_REUSEPORT for the
76 // server socket.
77 class ReusePortOption : public ::grpc::ServerBuilderOption {
78 public:
UpdateArguments(::grpc::ChannelArguments * args)79 void UpdateArguments(::grpc::ChannelArguments* args) override {
80 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 1);
81 }
82
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)83 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
84 plugins) override {}
85 };
86
87 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)88 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
89 return new RpcRendezvousMgr(env);
90 }
91
92 } // namespace
93
GrpcServer(const ServerDef & server_def,Env * env)94 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
95 : env_(env), state_(NEW), server_def_(server_def) {}
96
~GrpcServer()97 GrpcServer::~GrpcServer() {
98 TF_CHECK_OK(Stop());
99 TF_CHECK_OK(Join());
100
101 delete master_service_;
102 delete worker_service_;
103 delete eager_service_;
104
105 // TODO(mrry): Refactor the *Env classes so that it is less fiddly
106 // to destroy them.
107
108 // Shut down all outstanding rendezvous.
109 delete worker_env_.rendezvous_mgr;
110
111 // We must delete graph_mgr before device_mgr, due to shared
112 // ownership of OpKernels in the executors. (The graph_mgr will
113 // free all stateless OpKernels, and pass over borrowed stateful
114 // OpKernels, which are also held in their respective devices'
115 // OpSegments.)
116 if (worker_env_.session_mgr != nullptr) {
117 delete worker_env_.session_mgr; // Deletes graph_mgr's.
118 }
119
120 // Do not delete (as these are not owned by the server):
121 // - master_env_.env
122 // - worker_env_.env
123 // - worker_env_.compute_pool
124 }
125
MaybeMutateBuilder(::grpc::ServerBuilder * builder)126 void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
127
128 // Look up the requested host name and port for this task in `server_def`.
GetHostAndPort(const ServerDef & server_def,string * host_name,int * port) const129 Status GrpcServer::GetHostAndPort(const ServerDef& server_def,
130 string* host_name, int* port) const {
131 *port = -1;
132 *host_name = "localhost";
133 for (const auto& job : server_def.cluster().job()) {
134 if (job.name() == server_def.job_name()) {
135 auto iter = job.tasks().find(server_def.task_index());
136 if (iter == job.tasks().end()) {
137 return errors::Internal("Task ", server_def.task_index(),
138 " was not defined in job \"",
139 server_def.job_name(), "\"");
140 }
141
142 if (server_def.port() != 0) {
143 *port = server_def.port();
144 } else {
145 auto colon_index = iter->second.find_last_of(':');
146 if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
147 port)) {
148 return errors::InvalidArgument(
149 "Could not parse port for local server from \"", iter->second,
150 "\".");
151 }
152
153 if (colon_index != string::npos &&
154 !iter->second.substr(0, colon_index).empty()) {
155 *host_name = iter->second.substr(0, colon_index);
156 }
157 }
158 break;
159 }
160 }
161 if (*port == -1) {
162 return errors::Internal("Job \"", server_def.job_name(),
163 "\" was not defined in cluster");
164 }
165
166 return Status::OK();
167 }
168
Init(const GrpcServerOptions & opts)169 Status GrpcServer::Init(const GrpcServerOptions& opts) {
170 mutex_lock l(mu_);
171 CHECK_EQ(state_, NEW);
172 master_env_.env = env_;
173 worker_env_.env = env_;
174
175 // Check parameters before DeviceFactory::AddDevices,
176 // otherwise if 'task_index=-1' the program will abort.
177
178 int requested_port;
179 TF_RETURN_IF_ERROR(GetHostAndPort(server_def_, &host_name_, &requested_port));
180
181 SessionOptions sess_opts;
182 ConfigProto config = server_def_.default_session_config();
183 sess_opts.config = config;
184
185 // Configure shared devices between master and worker.
186 string name_prefix =
187 strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
188 "/task:", server_def_.task_index());
189 if (opts.local_device_mgr == nullptr) {
190 std::vector<std::unique_ptr<Device>> devices;
191 TF_RETURN_IF_ERROR(
192 DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
193 worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
194 owned_device_manager_.reset(worker_env_.device_mgr);
195 } else {
196 worker_env_.device_mgr = opts.local_device_mgr;
197 owned_device_manager_.reset(nullptr);
198 }
199 worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
200 master_env_.local_devices = worker_env_.device_mgr->ListDevices();
201 worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
202 ? new RpcRendezvousMgr(&worker_env_)
203 : opts.rendezvous_mgr_func(&worker_env_);
204 string unused;
205 string default_worker_name;
206 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
207 &default_worker_name, &unused)) {
208 return errors::Internal("Could not parse worker name.");
209 }
210
211 // N.B. The order of initialization here is intricate, because we
212 // wish to allow `requested_port == 0` (for choosing any port,
213 // mostly for testing). Therefore, the construction of the channel
214 // and worker caches depends on `bound_port_`, which is not set
215 // until we call `builder.BuildAndStart()`. We must create the
216 // service objects before calling `builder.BuildAndStart()`, but
217 // `master_env_` and `worker_env_` are only partially
218 // configured. However, this is not dangerous, because we do not
219 // start serving requests until `this->Start()` is called, which
220 // happens after this method returns.
221 //
222 // TODO(mrry): Provide a general mechanism for dynamically setting
223 // the identities of tasks in the worker pool after the service is
224 // running.
225 ::grpc::ServerBuilder builder;
226 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
227 GetServerCredentials(server_def_), &bound_port_);
228 builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
229
230 bool reuse_port = false;
231 const Status status =
232 ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port);
233 if (!status.ok()) {
234 LOG(ERROR) << status.error_message();
235 }
236 auto server_build_option =
237 reuse_port
238 ? std::unique_ptr<::grpc::ServerBuilderOption>(new ReusePortOption)
239 : std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption);
240 builder.SetOption(std::move(server_build_option));
241
242 // Allow subclasses to specify more args to pass to the gRPC server.
243 MaybeMutateBuilder(&builder);
244 master_impl_ = CreateMaster(&master_env_);
245 master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
246 worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
247 : NewGrpcWorker(&worker_env_, config);
248 worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
249 opts.worker_service_options)
250 .release();
251 eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
252
253 profiler_service_ = profiler::CreateProfilerService();
254 builder.RegisterService(profiler_service_.get());
255
256 // extra service:
257 if (opts.service_func != nullptr) {
258 opts.service_func(&worker_env_, &builder);
259 }
260 server_ = builder.BuildAndStart();
261
262 if (!server_) {
263 return errors::Unknown("Could not start gRPC server");
264 }
265 // Create the execution environment for the GRPC workers cache.
266 grpc_worker_env_.reset(CreateGrpcWorkerEnv());
267
268 WorkerCacheInterface* worker_cache;
269 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
270 TF_RETURN_IF_ERROR(
271 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
272 CHECK_NE(nullptr, worker_cache);
273
274 if (opts.collective_mgr_func) {
275 worker_env_.collective_executor_mgr.reset(
276 opts.collective_mgr_func(config, &worker_env_, worker_cache));
277 if (worker_env_.collective_executor_mgr == nullptr) {
278 return errors::Internal(
279 "collective_mgr_func did not return CollectiveExecutorMgr");
280 }
281 } else {
282 std::unique_ptr<DeviceResolverDistributed> dev_resolver(
283 new DeviceResolverDistributed(worker_env_.device_mgr));
284 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
285 new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
286 dev_resolver.get(), worker_cache,
287 default_worker_name));
288 worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
289 config, worker_env_.device_mgr, std::move(dev_resolver),
290 std::move(param_resolver), MaybeCreateNcclCommunicator(), worker_cache,
291 default_worker_name));
292 }
293
294 // Set up worker environment.
295 worker_env_.session_mgr = new SessionMgr(
296 &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
297 std::unique_ptr<WorkerCacheInterface>(worker_cache),
298 [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
299 WorkerCacheFactoryOptions options(server_def);
300 return WorkerCacheFactory(options, worker_cache);
301 });
302 worker_env_.compute_pool = ComputePool(sess_opts);
303
304 // Finish setting up master environment.
305 master_env_.ops = OpRegistry::Global();
306 master_env_.worker_cache = worker_cache;
307 master_env_.collective_executor_mgr =
308 worker_env_.collective_executor_mgr.get();
309 StatsPublisherFactory stats_factory = opts.stats_factory;
310 master_env_.master_session_factory =
311 [config, stats_factory](
312 SessionOptions options, const MasterEnv* env,
313 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
314 std::unique_ptr<WorkerCacheInterface> worker_cache,
315 std::unique_ptr<DeviceSet> device_set,
316 std::vector<string> filtered_worker_list) {
317 options.config.MergeFrom(config);
318 return new MasterSession(options, env, std::move(remote_devs),
319 std::move(worker_cache), std::move(device_set),
320 std::move(filtered_worker_list),
321 stats_factory);
322 };
323 master_env_.worker_cache_factory =
324 [this](const WorkerCacheFactoryOptions& options,
325 WorkerCacheInterface** worker_cache) {
326 return WorkerCacheFactory(options, worker_cache);
327 };
328
329 // Provide direct access to the master from in-process clients.
330 LocalMaster::Register(target(), master_impl_.get(),
331 config.operation_timeout_in_ms());
332
333 return Status::OK();
334 }
335
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)336 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
337 GrpcChannelSpec* channel_spec) {
338 for (const auto& job : options.cluster_def->job()) {
339 std::map<int, string> host_ports;
340 for (const auto& task : job.tasks()) {
341 string& host_port = host_ports[task.first];
342 if (!host_port.empty()) {
343 return errors::InvalidArgument("JobDef for job \"", job.name(),
344 "\" specified two addresses for task \"",
345 task.first, "\": ", host_port, " and ",
346 task.second);
347 }
348 if (job.name() == *options.job_name && task.first == options.task_index) {
349 host_port = strings::StrCat(host_name_, ":", bound_port_);
350 } else {
351 host_port = task.second;
352 }
353 }
354 TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
355 }
356 return Status::OK();
357 }
358
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)359 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
360 WorkerCacheInterface** worker_cache) {
361 if (options.job_name == nullptr || options.job_name->empty()) {
362 Status s = errors::InvalidArgument(
363 "The master (current machine) is not included in the provided "
364 "cluster_def. ",
365 options.cluster_def->DebugString());
366 LOG(WARNING) << s;
367 return s;
368 }
369
370 GrpcChannelSpec channel_spec;
371 TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
372
373 std::shared_ptr<GrpcChannelCache> channel_cache(
374 NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
375
376 string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
377 "/task:", options.task_index);
378
379 const string host_port = channel_cache->TranslateTask(name_prefix);
380 int requested_port;
381
382 auto colon_index = host_port.find_last_of(':');
383 if (!strings::safe_strto32(host_port.substr(colon_index + 1),
384 &requested_port)) {
385 return errors::Internal("Could not parse port for local server from \"",
386 host_port, "\".");
387 }
388 if (requested_port != bound_port_) {
389 return errors::InvalidArgument("Requested port ", requested_port,
390 " differs from expected port ", bound_port_);
391 }
392 *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
393 channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
394 return Status::OK();
395 }
396
Start()397 Status GrpcServer::Start() {
398 mutex_lock l(mu_);
399 switch (state_) {
400 case NEW: {
401 master_thread_.reset(
402 env_->StartThread(ThreadOptions(), "TF_master_service",
403 [this] { master_service_->HandleRPCsLoop(); }));
404 worker_thread_.reset(
405 env_->StartThread(ThreadOptions(), "TF_worker_service",
406 [this] { worker_service_->HandleRPCsLoop(); }));
407 eager_thread_.reset(
408 env_->StartThread(ThreadOptions(), "TF_eager_service",
409 [this] { eager_service_->HandleRPCsLoop(); }));
410 state_ = STARTED;
411 LOG(INFO) << "Started server with target: " << target();
412 return Status::OK();
413 }
414 case STARTED:
415 LOG(INFO) << "Server already started (target: " << target() << ")";
416 return Status::OK();
417 case STOPPED:
418 return errors::FailedPrecondition("Server has stopped.");
419 default:
420 LOG(FATAL);
421 }
422 }
423
AddMasterEagerContextToEagerService(const tensorflow::uint64 context_id,tensorflow::EagerContext * context)424 Status GrpcServer::AddMasterEagerContextToEagerService(
425 const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
426 auto* eager_service =
427 static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
428 return eager_service->CreateMasterContext(context_id, context);
429 }
430
UpdateServerDef(const ServerDef & server_def)431 Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
432 mutex_lock l(mu_);
433 server_def_ = server_def;
434 WorkerCacheInterface* worker_cache;
435 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
436 TF_RETURN_IF_ERROR(
437 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
438 if (worker_cache == nullptr) {
439 return errors::InvalidArgument(
440 "Failed to build worker cache with the provided server def.");
441 }
442 // Transfer ownership of worker_cache to worker_env_.session_mgr.
443 worker_env_.session_mgr->ResetDefaultWorkerCache(worker_cache);
444
445 string default_worker_name;
446 string unused;
447 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
448 &default_worker_name, &unused)) {
449 return errors::Internal("Could not parse worker name.");
450 }
451 std::unique_ptr<DeviceResolverDistributed> dev_resolver(
452 new DeviceResolverDistributed(worker_env_.device_mgr));
453 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
454 new CollectiveParamResolverDistributed(
455 server_def_.default_session_config(), worker_env_.device_mgr,
456 dev_resolver.get(), worker_cache, default_worker_name));
457 worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
458 server_def_.default_session_config(), worker_env_.device_mgr,
459 std::move(dev_resolver), std::move(param_resolver),
460 MaybeCreateNcclCommunicator(), worker_cache, default_worker_name));
461
462 master_env_.worker_cache = worker_cache;
463 master_env_.collective_executor_mgr =
464 worker_env_.collective_executor_mgr.get();
465 return Status::OK();
466 }
467
Stop()468 Status GrpcServer::Stop() {
469 mutex_lock l(mu_);
470 switch (state_) {
471 case NEW:
472 state_ = STOPPED;
473 return Status::OK();
474 case STARTED:
475 return errors::Unimplemented(
476 "Clean shutdown is not currently implemented");
477 case STOPPED:
478 LOG(INFO) << "Server already stopped (target: " << target() << ")";
479 return Status::OK();
480 default:
481 LOG(FATAL);
482 }
483 }
484
Join()485 Status GrpcServer::Join() {
486 mutex_lock l(mu_);
487 switch (state_) {
488 case NEW:
489 // Prevent the server from being started subsequently.
490 state_ = STOPPED;
491 return Status::OK();
492 case STARTED:
493 case STOPPED:
494 master_thread_.reset();
495 worker_thread_.reset();
496 eager_thread_.reset();
497 return Status::OK();
498 default:
499 LOG(FATAL);
500 }
501 }
502
target() const503 const string GrpcServer::target() const {
504 return strings::StrCat("grpc://", host_name_, ":", bound_port_);
505 }
506
GetServerCredentials(const ServerDef & server_def) const507 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
508 const ServerDef& server_def) const {
509 return ::grpc::InsecureServerCredentials();
510 }
511
GetChannelCreationFunction() const512 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
513 // We can do this because SparseGrpcChannelCache is robust to nullptr being
514 // returned by the channel creation function
515 return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
516 }
517
CreateMaster(MasterEnv * master_env)518 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
519 return std::unique_ptr<Master>(new Master(master_env, 0.0));
520 }
521
522 /* static */
Create(const ServerDef & server_def,Env * env,const DeviceMgr * local_device_mgr,std::unique_ptr<ServerInterface> * out_server)523 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
524 const DeviceMgr* local_device_mgr,
525 std::unique_ptr<ServerInterface>* out_server) {
526 std::unique_ptr<GrpcServer> ret(
527 new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
528 GrpcServerOptions options;
529 options.rendezvous_mgr_func = NewRpcRendezvousMgr;
530 options.local_device_mgr = local_device_mgr;
531 Status s = ret->Init(options);
532 if (!s.ok()) {
533 LOG(ERROR) << s;
534 return s;
535 }
536 *out_server = std::move(ret);
537 return Status::OK();
538 }
539
540 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)541 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
542 std::unique_ptr<ServerInterface>* out_server) {
543 return Create(server_def, env, nullptr, out_server);
544 }
545
546 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)547 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
548 std::unique_ptr<GrpcServer>* out_server) {
549 std::unique_ptr<ServerInterface> server;
550 Status s = Create(server_def, env, nullptr, &server);
551 if (!s.ok()) {
552 return s;
553 }
554 out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
555 return Status::OK();
556 }
557
558 namespace {
559
560 class GrpcServerFactory : public ServerFactory {
561 public:
AcceptsOptions(const ServerDef & server_def)562 bool AcceptsOptions(const ServerDef& server_def) override {
563 return server_def.protocol() == "grpc";
564 }
565
NewServer(const ServerDef & server_def,const Options & options,std::unique_ptr<ServerInterface> * out_server)566 Status NewServer(const ServerDef& server_def, const Options& options,
567 std::unique_ptr<ServerInterface>* out_server) override {
568 return GrpcServer::Create(server_def, Env::Default(),
569 options.local_device_mgr, out_server);
570 }
571 };
572
573 // Registers a `ServerFactory` for `GrpcServer` instances.
574 class GrpcServerRegistrar {
575 public:
GrpcServerRegistrar()576 GrpcServerRegistrar() {
577 ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
578 }
579 };
580 static GrpcServerRegistrar registrar;
581
582 } // namespace
583 } // namespace tensorflow
584