1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "grpcpp/server_builder.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/process_util.h"
29 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
30 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
31 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
32 #include "tensorflow/core/distributed_runtime/local_master.h"
33 #include "tensorflow/core/distributed_runtime/master.h"
34 #include "tensorflow/core/distributed_runtime/master_env.h"
35 #include "tensorflow/core/distributed_runtime/master_session.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
42 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
43 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
44 #include "tensorflow/core/distributed_runtime/server_lib.h"
45 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
46 #include "tensorflow/core/distributed_runtime/worker_env.h"
47 #include "tensorflow/core/framework/op.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/nccl/collective_communicator.h"
51 #include "tensorflow/core/platform/cpu_info.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mem.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
56 #include "tensorflow/core/public/session_options.h"
57 #include "tensorflow/core/util/env_var.h"
58
59 namespace tensorflow {
60
61 namespace {
62
63 // Define an option subclass in order to disable SO_REUSEPORT for the
64 // server socket.
65 class NoReusePortOption : public ::grpc::ServerBuilderOption {
66 public:
UpdateArguments(::grpc::ChannelArguments * args)67 void UpdateArguments(::grpc::ChannelArguments* args) override {
68 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
69 }
70
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)71 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
72 plugins) override {}
73 };
74
75 // Define an option subclass in order to enable SO_REUSEPORT for the
76 // server socket.
77 class ReusePortOption : public ::grpc::ServerBuilderOption {
78 public:
UpdateArguments(::grpc::ChannelArguments * args)79 void UpdateArguments(::grpc::ChannelArguments* args) override {
80 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 1);
81 }
82
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)83 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
84 plugins) override {}
85 };
86
87 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)88 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
89 return new RpcRendezvousMgr(env);
90 }
91
92 } // namespace
93
GrpcServer(const ServerDef & server_def,Env * env)94 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
95 : env_(env), state_(NEW), server_def_(server_def) {}
96
~GrpcServer()97 GrpcServer::~GrpcServer() {
98 TF_CHECK_OK(Stop());
99 TF_CHECK_OK(Join());
100
101 delete master_service_;
102 delete worker_service_;
103 delete eager_service_;
104
105 for (auto& kv : extra_services_) {
106 AsyncServiceInterface* service = kv.second;
107 delete service;
108 }
109
110 // TODO(mrry): Refactor the *Env classes so that it is less fiddly
111 // to destroy them.
112
113 // Shut down all outstanding rendezvous.
114 delete worker_env_.rendezvous_mgr;
115
116 // We must delete graph_mgr before device_mgr, due to shared
117 // ownership of OpKernels in the executors. (The graph_mgr will
118 // free all stateless OpKernels, and pass over borrowed stateful
119 // OpKernels, which are also held in their respective devices'
120 // OpSegments.)
121 if (worker_env_.session_mgr != nullptr) {
122 delete worker_env_.session_mgr; // Deletes graph_mgr's.
123 }
124
125 // Do not delete (as these are not owned by the server):
126 // - master_env_.env
127 // - worker_env_.env
128 // - worker_env_.compute_pool
129 }
130
131 // Look up the requested host name and port for this task in `server_def`.
GetHostAndPort(const ServerDef & server_def,string * host_name,int * port) const132 Status GrpcServer::GetHostAndPort(const ServerDef& server_def,
133 string* host_name, int* port) const {
134 *port = -1;
135 *host_name = "localhost";
136 for (const auto& job : server_def.cluster().job()) {
137 if (job.name() == server_def.job_name()) {
138 auto iter = job.tasks().find(server_def.task_index());
139 if (iter == job.tasks().end()) {
140 return errors::Internal("Task ", server_def.task_index(),
141 " was not defined in job \"",
142 server_def.job_name(), "\"");
143 }
144
145 if (server_def.port() != 0) {
146 *port = server_def.port();
147 } else {
148 auto colon_index = iter->second.find_last_of(':');
149 if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
150 port)) {
151 return errors::InvalidArgument(
152 "Could not parse port for local server from \"", iter->second,
153 "\".");
154 }
155
156 if (colon_index != string::npos &&
157 !iter->second.substr(0, colon_index).empty()) {
158 *host_name = iter->second.substr(0, colon_index);
159 }
160 }
161 break;
162 }
163 }
164 if (*port == -1) {
165 return errors::Internal("Job \"", server_def.job_name(),
166 "\" was not defined in cluster");
167 }
168
169 return Status::OK();
170 }
171
Init(const GrpcServerOptions & opts)172 Status GrpcServer::Init(const GrpcServerOptions& opts) {
173 mutex_lock l(mu_);
174 CHECK_EQ(state_, NEW);
175 master_env_.env = env_;
176 worker_env_.env = env_;
177
178 // Check parameters before DeviceFactory::AddDevices,
179 // otherwise if 'task_index=-1' the program will abort.
180
181 int requested_port;
182 TF_RETURN_IF_ERROR(GetHostAndPort(server_def_, &host_name_, &requested_port));
183
184 SessionOptions sess_opts;
185 VLOG(3) << "Grpc Server Init Definition: " << server_def_.DebugString();
186 ConfigProto config = server_def_.default_session_config();
187 sess_opts.config = config;
188
189 // Configure shared devices between master and worker.
190 string name_prefix =
191 strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
192 "/task:", server_def_.task_index());
193 if (opts.local_device_mgr == nullptr) {
194 std::vector<std::unique_ptr<Device>> devices;
195 TF_RETURN_IF_ERROR(
196 DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
197 worker_env_.device_mgr = new DynamicDeviceMgr(std::move(devices));
198 owned_device_manager_.reset(worker_env_.device_mgr);
199 } else {
200 worker_env_.device_mgr = opts.local_device_mgr;
201 owned_device_manager_.reset(nullptr);
202 }
203 worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
204 master_env_.local_devices = worker_env_.device_mgr->ListDevices();
205 worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
206 ? new RpcRendezvousMgr(&worker_env_)
207 : opts.rendezvous_mgr_func(&worker_env_);
208 string unused;
209 string default_worker_name;
210 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
211 &default_worker_name, &unused)) {
212 return errors::Internal("Could not parse worker name.");
213 }
214
215 // N.B. The order of initialization here is intricate, because we
216 // wish to allow `requested_port == 0` (for choosing any port,
217 // mostly for testing). Therefore, the construction of the channel
218 // and worker caches depends on `bound_port_`, which is not set
219 // until we call `builder.BuildAndStart()`. We must create the
220 // service objects before calling `builder.BuildAndStart()`, but
221 // `master_env_` and `worker_env_` are only partially
222 // configured. However, this is not dangerous, because we do not
223 // start serving requests until `this->Start()` is called, which
224 // happens after this method returns.
225 //
226 // TODO(mrry): Provide a general mechanism for dynamically setting
227 // the identities of tasks in the worker pool after the service is
228 // running.
229 ::grpc::ServerBuilder builder;
230 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
231 GetServerCredentials(server_def_), &bound_port_);
232 builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
233
234 bool reuse_port = false;
235 const Status status =
236 ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port);
237 if (!status.ok()) {
238 LOG(ERROR) << status.error_message();
239 }
240 auto server_build_option =
241 reuse_port
242 ? std::unique_ptr<::grpc::ServerBuilderOption>(new ReusePortOption)
243 : std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption);
244 builder.SetOption(std::move(server_build_option));
245
246 // Allow subclasses to specify more args to pass to the gRPC server.
247 MaybeMutateBuilder(&builder, requested_port);
248 master_impl_ = CreateMaster(&master_env_);
249 master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
250 worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
251 : NewGrpcWorker(&worker_env_, config);
252 worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
253 opts.worker_service_options)
254 .release();
255 eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
256
257 profiler_service_ = profiler::CreateProfilerService();
258 builder.RegisterService(profiler_service_.get());
259
260 // Add any extra services to be started.
261 extra_services_ = ExtraServices(&builder);
262
263 // extra service:
264 if (opts.service_func != nullptr) {
265 opts.service_func(&worker_env_, &builder);
266 }
267 server_ = builder.BuildAndStart();
268
269 if (!server_) {
270 return errors::Unknown("Could not start gRPC server");
271 }
272 // Create the execution environment for the GRPC workers cache.
273 grpc_worker_env_.reset(CreateGrpcWorkerEnv());
274
275 WorkerCacheInterface* worker_cache;
276 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
277 TF_RETURN_IF_ERROR(
278 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
279 CHECK_NE(nullptr, worker_cache);
280
281 if (opts.collective_mgr_func) {
282 worker_env_.collective_executor_mgr.reset(
283 opts.collective_mgr_func(config, &worker_env_, worker_cache));
284 if (worker_env_.collective_executor_mgr == nullptr) {
285 return errors::Internal(
286 "collective_mgr_func did not return CollectiveExecutorMgr");
287 }
288 } else {
289 worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
290 config, worker_env_.device_mgr, MaybeCreateNcclCommunicator(config),
291 worker_cache, default_worker_name);
292 }
293
294 // Set up worker environment.
295 worker_env_.session_mgr = new SessionMgr(
296 &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
297 std::unique_ptr<WorkerCacheInterface>(worker_cache),
298 [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
299 WorkerCacheFactoryOptions options(server_def);
300 return WorkerCacheFactory(options, worker_cache);
301 });
302 worker_env_.compute_pool = ComputePool(sess_opts);
303
304 // Finish setting up master environment.
305 master_env_.ops = OpRegistry::Global();
306 master_env_.worker_cache = worker_cache;
307 master_env_.collective_executor_mgr =
308 worker_env_.collective_executor_mgr.get();
309 StatsPublisherFactory stats_factory = opts.stats_factory;
310 master_env_.master_session_factory =
311 [config, stats_factory](
312 SessionOptions options, const MasterEnv* env,
313 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
314 std::unique_ptr<WorkerCacheInterface> worker_cache,
315 std::unique_ptr<DeviceSet> device_set,
316 std::vector<string> filtered_worker_list) {
317 options.config.MergeFrom(config);
318 return new MasterSession(options, env, std::move(remote_devs),
319 std::move(worker_cache), std::move(device_set),
320 std::move(filtered_worker_list),
321 stats_factory);
322 };
323 master_env_.worker_cache_factory =
324 [this](const WorkerCacheFactoryOptions& options,
325 WorkerCacheInterface** worker_cache) {
326 return WorkerCacheFactory(options, worker_cache);
327 };
328
329 // Provide direct access to the master from in-process clients.
330 LocalMaster::Register(target(), master_impl_.get(),
331 config.operation_timeout_in_ms());
332
333 return Status::OK();
334 }
335
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)336 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
337 GrpcChannelSpec* channel_spec) {
338 for (const auto& job : options.cluster_def->job()) {
339 std::map<int, string> host_ports;
340 for (const auto& task : job.tasks()) {
341 string& host_port = host_ports[task.first];
342 if (!host_port.empty()) {
343 return errors::InvalidArgument("JobDef for job \"", job.name(),
344 "\" specified two addresses for task \"",
345 task.first, "\": ", host_port, " and ",
346 task.second);
347 }
348 if (job.name() == *options.job_name && task.first == options.task_index) {
349 host_port = strings::StrCat(host_name_, ":", bound_port_);
350 } else {
351 host_port = task.second;
352 }
353 }
354 TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
355 }
356 return Status::OK();
357 }
358
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)359 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
360 WorkerCacheInterface** worker_cache) {
361 if (options.job_name == nullptr || options.job_name->empty()) {
362 Status s = errors::InvalidArgument(
363 "The master (current machine) is not included in the provided "
364 "cluster_def. ",
365 options.cluster_def->DebugString());
366 LOG(WARNING) << s;
367 return s;
368 }
369
370 GrpcChannelSpec channel_spec;
371 TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
372
373 if (options.rpc_options == nullptr) {
374 return errors::InvalidArgument(
375 "rpc_options not set in WorkerCacheFactoryOptions");
376 }
377 std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
378 channel_spec, GetChannelCreationFunction(), *options.rpc_options));
379
380 string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
381 "/task:", options.task_index);
382
383 const string host_port = channel_cache->TranslateTask(name_prefix);
384 int requested_port;
385
386 auto colon_index = host_port.find_last_of(':');
387 if (!strings::safe_strto32(host_port.substr(colon_index + 1),
388 &requested_port)) {
389 return errors::Internal("Could not parse port for local server from \"",
390 host_port, "\".");
391 }
392 if (requested_port != bound_port_) {
393 return errors::InvalidArgument("Requested port ", requested_port,
394 " differs from expected port ", bound_port_);
395 }
396 *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
397 channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
398 return Status::OK();
399 }
400
Start()401 Status GrpcServer::Start() {
402 mutex_lock l(mu_);
403 switch (state_) {
404 case NEW: {
405 master_thread_.reset(
406 env_->StartThread(ThreadOptions(), "TF_master_service",
407 [this] { master_service_->HandleRPCsLoop(); }));
408 worker_thread_.reset(
409 env_->StartThread(ThreadOptions(), "TF_worker_service",
410 [this] { worker_service_->HandleRPCsLoop(); }));
411 eager_thread_.reset(
412 env_->StartThread(ThreadOptions(), "TF_eager_service",
413 [this] { eager_service_->HandleRPCsLoop(); }));
414
415 for (const auto& kv : extra_services_) {
416 const std::string& service_name = kv.first;
417 AsyncServiceInterface* service = kv.second;
418 std::unique_ptr<Thread> extra_service_thread;
419 extra_service_thread.reset(env_->StartThread(
420 ThreadOptions(), service_name,
421 [service = service] { service->HandleRPCsLoop(); }));
422 extra_service_threads_.push_back(std::move(extra_service_thread));
423 VLOG(3) << "Started extra service: " << service_name;
424 }
425
426 state_ = STARTED;
427 LOG(INFO) << "Started server with target: " << target();
428 return Status::OK();
429 }
430 case STARTED:
431 LOG(INFO) << "Server already started (target: " << target() << ")";
432 return Status::OK();
433 case STOPPED:
434 return errors::FailedPrecondition("Server has stopped.");
435 default:
436 LOG(FATAL);
437 }
438 }
439
AddMasterEagerContextToEagerService(const tensorflow::uint64 context_id,tensorflow::EagerContext * context)440 Status GrpcServer::AddMasterEagerContextToEagerService(
441 const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
442 auto* eager_service =
443 static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
444 return eager_service->CreateMasterContext(context_id, context);
445 }
446
UpdateServerDef(const ServerDef & server_def)447 Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
448 mutex_lock l(mu_);
449 server_def_ = server_def;
450 WorkerCacheInterface* worker_cache;
451 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
452 TF_RETURN_IF_ERROR(
453 WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
454 if (worker_cache == nullptr) {
455 return errors::InvalidArgument(
456 "Failed to build worker cache with the provided server def.");
457 }
458 // Transfer ownership of worker_cache to worker_env_.session_mgr.
459 worker_env_.session_mgr->ResetDefaultWorkerCache(worker_cache);
460
461 string default_worker_name;
462 string unused;
463 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
464 &default_worker_name, &unused)) {
465 return errors::Internal("Could not parse worker name.");
466 }
467 worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
468 server_def_.default_session_config(), worker_env_.device_mgr,
469 MaybeCreateNcclCommunicator(server_def_.default_session_config()),
470 worker_cache, default_worker_name);
471
472 master_env_.worker_cache = worker_cache;
473 master_env_.collective_executor_mgr =
474 worker_env_.collective_executor_mgr.get();
475 return Status::OK();
476 }
477
478 // TODO(haoyuzhang): Remove this method once we have a mechanism to directly set
479 // field inside the RPC coordination service handler.
SetCoordinationServiceAgentInstance(CoordinationServiceAgent * agent)480 Status GrpcServer::SetCoordinationServiceAgentInstance(
481 CoordinationServiceAgent* agent) {
482 // No op, coordination service is not implemented in open source.
483 return Status::OK();
484 }
485
Stop()486 Status GrpcServer::Stop() {
487 mutex_lock l(mu_);
488 switch (state_) {
489 case NEW:
490 state_ = STOPPED;
491 return Status::OK();
492 case STARTED:
493 return errors::Unimplemented(
494 "Clean shutdown is not currently implemented");
495 case STOPPED:
496 LOG(INFO) << "Server already stopped (target: " << target() << ")";
497 return Status::OK();
498 default:
499 LOG(FATAL);
500 }
501 }
502
Join()503 Status GrpcServer::Join() {
504 mutex_lock l(mu_);
505 switch (state_) {
506 case NEW:
507 // Prevent the server from being started subsequently.
508 state_ = STOPPED;
509 return Status::OK();
510 case STARTED:
511 case STOPPED:
512 master_thread_.reset();
513 worker_thread_.reset();
514 eager_thread_.reset();
515 for (auto& thread : extra_service_threads_) {
516 thread.reset();
517 }
518 return Status::OK();
519 default:
520 LOG(FATAL);
521 }
522 }
523
target() const524 const string GrpcServer::target() const {
525 return strings::StrCat("grpc://", host_name_, ":", bound_port_);
526 }
527
GetServerCredentials(const ServerDef & server_def) const528 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
529 const ServerDef& server_def) const {
530 return ::grpc::InsecureServerCredentials();
531 }
532
GetChannelCreationFunction() const533 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
534 // We can do this because SparseGrpcChannelCache is robust to nullptr being
535 // returned by the channel creation function
536 return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
537 }
538
CreateMaster(MasterEnv * master_env)539 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
540 return std::unique_ptr<Master>(new Master(master_env, 0.0));
541 }
542
543 /* static */
Create(const ServerDef & server_def,Env * env,DeviceMgr * local_device_mgr,std::unique_ptr<ServerInterface> * out_server)544 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
545 DeviceMgr* local_device_mgr,
546 std::unique_ptr<ServerInterface>* out_server) {
547 std::unique_ptr<GrpcServer> ret(
548 new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
549 GrpcServerOptions options;
550 options.rendezvous_mgr_func = NewRpcRendezvousMgr;
551 options.local_device_mgr = local_device_mgr;
552 Status s = ret->Init(options);
553 if (!s.ok()) {
554 LOG(ERROR) << s;
555 return s;
556 }
557 *out_server = std::move(ret);
558 return Status::OK();
559 }
560
561 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)562 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
563 std::unique_ptr<ServerInterface>* out_server) {
564 return Create(server_def, env, nullptr, out_server);
565 }
566
567 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)568 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
569 std::unique_ptr<GrpcServer>* out_server) {
570 std::unique_ptr<ServerInterface> server;
571 Status s = Create(server_def, env, nullptr, &server);
572 if (!s.ok()) {
573 return s;
574 }
575 out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
576 return Status::OK();
577 }
578
579 namespace {
580
581 class GrpcServerFactory : public ServerFactory {
582 public:
AcceptsOptions(const ServerDef & server_def)583 bool AcceptsOptions(const ServerDef& server_def) override {
584 return server_def.protocol() == "grpc";
585 }
586
NewServer(const ServerDef & server_def,const Options & options,std::unique_ptr<ServerInterface> * out_server)587 Status NewServer(const ServerDef& server_def, const Options& options,
588 std::unique_ptr<ServerInterface>* out_server) override {
589 return GrpcServer::Create(server_def, Env::Default(),
590 options.local_device_mgr, out_server);
591 }
592 };
593
594 // Registers a `ServerFactory` for `GrpcServer` instances.
595 class GrpcServerRegistrar {
596 public:
GrpcServerRegistrar()597 GrpcServerRegistrar() {
598 ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
599 }
600 };
601 static GrpcServerRegistrar registrar;
602
603 } // namespace
604 } // namespace tensorflow
605