• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17 
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "grpcpp/server_builder.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/process_util.h"
29 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
30 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
31 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
32 #include "tensorflow/core/distributed_runtime/local_master.h"
33 #include "tensorflow/core/distributed_runtime/master.h"
34 #include "tensorflow/core/distributed_runtime/master_env.h"
35 #include "tensorflow/core/distributed_runtime/master_session.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
42 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
43 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
44 #include "tensorflow/core/distributed_runtime/server_lib.h"
45 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
46 #include "tensorflow/core/distributed_runtime/worker_env.h"
47 #include "tensorflow/core/framework/op.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/nccl/collective_communicator.h"
51 #include "tensorflow/core/platform/cpu_info.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mem.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
56 #include "tensorflow/core/public/session_options.h"
57 #include "tensorflow/core/util/env_var.h"
58 
59 namespace tensorflow {
60 
61 namespace {
62 
63 // Define an option subclass in order to disable SO_REUSEPORT for the
64 // server socket.
65 class NoReusePortOption : public ::grpc::ServerBuilderOption {
66  public:
UpdateArguments(::grpc::ChannelArguments * args)67   void UpdateArguments(::grpc::ChannelArguments* args) override {
68     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
69   }
70 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)71   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
72                          plugins) override {}
73 };
74 
75 // Define an option subclass in order to enable SO_REUSEPORT for the
76 // server socket.
77 class ReusePortOption : public ::grpc::ServerBuilderOption {
78  public:
UpdateArguments(::grpc::ChannelArguments * args)79   void UpdateArguments(::grpc::ChannelArguments* args) override {
80     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 1);
81   }
82 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)83   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
84                          plugins) override {}
85 };
86 
87 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)88 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
89   return new RpcRendezvousMgr(env);
90 }
91 
92 }  // namespace
93 
GrpcServer(const ServerDef & server_def,Env * env)94 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
95     : env_(env), state_(NEW), server_def_(server_def) {}
96 
~GrpcServer()97 GrpcServer::~GrpcServer() {
98   TF_CHECK_OK(Stop());
99   TF_CHECK_OK(Join());
100 
101   delete master_service_;
102   delete worker_service_;
103   delete eager_service_;
104 
105   // TODO(mrry): Refactor the *Env classes so that it is less fiddly
106   // to destroy them.
107 
108   // Shut down all outstanding rendezvous.
109   delete worker_env_.rendezvous_mgr;
110 
111   // We must delete graph_mgr before device_mgr, due to shared
112   // ownership of OpKernels in the executors. (The graph_mgr will
113   // free all stateless OpKernels, and pass over borrowed stateful
114   // OpKernels, which are also held in their respective devices'
115   // OpSegments.)
116   if (worker_env_.session_mgr != nullptr) {
117     delete worker_env_.session_mgr;  // Deletes graph_mgr's.
118   }
119 
120   // Do not delete (as these are not owned by the server):
121   // - master_env_.env
122   // - worker_env_.env
123   // - worker_env_.compute_pool
124 }
125 
MaybeMutateBuilder(::grpc::ServerBuilder * builder)126 void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
127 
128 // Look up the requested host name and port for this task in `server_def`.
GetHostAndPort(const ServerDef & server_def,string * host_name,int * port) const129 Status GrpcServer::GetHostAndPort(const ServerDef& server_def,
130                                   string* host_name, int* port) const {
131   *port = -1;
132   *host_name = "localhost";
133   for (const auto& job : server_def.cluster().job()) {
134     if (job.name() == server_def.job_name()) {
135       auto iter = job.tasks().find(server_def.task_index());
136       if (iter == job.tasks().end()) {
137         return errors::Internal("Task ", server_def.task_index(),
138                                 " was not defined in job \"",
139                                 server_def.job_name(), "\"");
140       }
141 
142       if (server_def.port() != 0) {
143         *port = server_def.port();
144       } else {
145         auto colon_index = iter->second.find_last_of(':');
146         if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
147                                    port)) {
148           return errors::InvalidArgument(
149               "Could not parse port for local server from \"", iter->second,
150               "\".");
151         }
152 
153         if (colon_index != string::npos &&
154             !iter->second.substr(0, colon_index).empty()) {
155           *host_name = iter->second.substr(0, colon_index);
156         }
157       }
158       break;
159     }
160   }
161   if (*port == -1) {
162     return errors::Internal("Job \"", server_def.job_name(),
163                             "\" was not defined in cluster");
164   }
165 
166   return Status::OK();
167 }
168 
Init(const GrpcServerOptions & opts)169 Status GrpcServer::Init(const GrpcServerOptions& opts) {
170   mutex_lock l(mu_);
171   CHECK_EQ(state_, NEW);
172   master_env_.env = env_;
173   worker_env_.env = env_;
174 
175   // Check parameters before DeviceFactory::AddDevices,
176   // otherwise if 'task_index=-1' the program will abort.
177 
178   int requested_port;
179   TF_RETURN_IF_ERROR(GetHostAndPort(server_def_, &host_name_, &requested_port));
180 
181   SessionOptions sess_opts;
182   ConfigProto config = server_def_.default_session_config();
183   sess_opts.config = config;
184 
185   // Configure shared devices between master and worker.
186   string name_prefix =
187       strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
188                       "/task:", server_def_.task_index());
189   if (opts.local_device_mgr == nullptr) {
190     std::vector<std::unique_ptr<Device>> devices;
191     TF_RETURN_IF_ERROR(
192         DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
193     worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
194     owned_device_manager_.reset(worker_env_.device_mgr);
195   } else {
196     worker_env_.device_mgr = opts.local_device_mgr;
197     owned_device_manager_.reset(nullptr);
198   }
199   worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
200   master_env_.local_devices = worker_env_.device_mgr->ListDevices();
201   worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
202                                    ? new RpcRendezvousMgr(&worker_env_)
203                                    : opts.rendezvous_mgr_func(&worker_env_);
204   string unused;
205   string default_worker_name;
206   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
207                                         &default_worker_name, &unused)) {
208     return errors::Internal("Could not parse worker name.");
209   }
210 
211   // N.B. The order of initialization here is intricate, because we
212   // wish to allow `requested_port == 0` (for choosing any port,
213   // mostly for testing). Therefore, the construction of the channel
214   // and worker caches depends on `bound_port_`, which is not set
215   // until we call `builder.BuildAndStart()`. We must create the
216   // service objects before calling `builder.BuildAndStart()`, but
217   // `master_env_` and `worker_env_` are only partially
218   // configured. However, this is not dangerous, because we do not
219   // start serving requests until `this->Start()` is called, which
220   // happens after this method returns.
221   //
222   // TODO(mrry): Provide a general mechanism for dynamically setting
223   // the identities of tasks in the worker pool after the service is
224   // running.
225   ::grpc::ServerBuilder builder;
226   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
227                            GetServerCredentials(server_def_), &bound_port_);
228   builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
229 
230   bool reuse_port = false;
231   const Status status =
232       ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port);
233   if (!status.ok()) {
234     LOG(ERROR) << status.error_message();
235   }
236   auto server_build_option =
237       reuse_port
238           ? std::unique_ptr<::grpc::ServerBuilderOption>(new ReusePortOption)
239           : std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption);
240   builder.SetOption(std::move(server_build_option));
241 
242   // Allow subclasses to specify more args to pass to the gRPC server.
243   MaybeMutateBuilder(&builder);
244   master_impl_ = CreateMaster(&master_env_);
245   master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
246   worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
247                                   : NewGrpcWorker(&worker_env_, config);
248   worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
249                                          opts.worker_service_options)
250                         .release();
251   eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
252 
253   profiler_service_ = profiler::CreateProfilerService();
254   builder.RegisterService(profiler_service_.get());
255 
256   // extra service:
257   if (opts.service_func != nullptr) {
258     opts.service_func(&worker_env_, &builder);
259   }
260   server_ = builder.BuildAndStart();
261 
262   if (!server_) {
263     return errors::Unknown("Could not start gRPC server");
264   }
265   // Create the execution environment for the GRPC workers cache.
266   grpc_worker_env_.reset(CreateGrpcWorkerEnv());
267 
268   WorkerCacheInterface* worker_cache;
269   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
270   TF_RETURN_IF_ERROR(
271       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
272   CHECK_NE(nullptr, worker_cache);
273 
274   if (opts.collective_mgr_func) {
275     worker_env_.collective_executor_mgr.reset(
276         opts.collective_mgr_func(config, &worker_env_, worker_cache));
277     if (worker_env_.collective_executor_mgr == nullptr) {
278       return errors::Internal(
279           "collective_mgr_func did not return CollectiveExecutorMgr");
280     }
281   } else {
282     std::unique_ptr<DeviceResolverDistributed> dev_resolver(
283         new DeviceResolverDistributed(worker_env_.device_mgr));
284     std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
285         new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
286                                                dev_resolver.get(), worker_cache,
287                                                default_worker_name));
288     worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
289         config, worker_env_.device_mgr, std::move(dev_resolver),
290         std::move(param_resolver), MaybeCreateNcclCommunicator(), worker_cache,
291         default_worker_name));
292   }
293 
294   // Set up worker environment.
295   worker_env_.session_mgr = new SessionMgr(
296       &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
297       std::unique_ptr<WorkerCacheInterface>(worker_cache),
298       [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
299         WorkerCacheFactoryOptions options(server_def);
300         return WorkerCacheFactory(options, worker_cache);
301       });
302   worker_env_.compute_pool = ComputePool(sess_opts);
303 
304   // Finish setting up master environment.
305   master_env_.ops = OpRegistry::Global();
306   master_env_.worker_cache = worker_cache;
307   master_env_.collective_executor_mgr =
308       worker_env_.collective_executor_mgr.get();
309   StatsPublisherFactory stats_factory = opts.stats_factory;
310   master_env_.master_session_factory =
311       [config, stats_factory](
312           SessionOptions options, const MasterEnv* env,
313           std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
314           std::unique_ptr<WorkerCacheInterface> worker_cache,
315           std::unique_ptr<DeviceSet> device_set,
316           std::vector<string> filtered_worker_list) {
317         options.config.MergeFrom(config);
318         return new MasterSession(options, env, std::move(remote_devs),
319                                  std::move(worker_cache), std::move(device_set),
320                                  std::move(filtered_worker_list),
321                                  stats_factory);
322       };
323   master_env_.worker_cache_factory =
324       [this](const WorkerCacheFactoryOptions& options,
325              WorkerCacheInterface** worker_cache) {
326         return WorkerCacheFactory(options, worker_cache);
327       };
328 
329   // Provide direct access to the master from in-process clients.
330   LocalMaster::Register(target(), master_impl_.get(),
331                         config.operation_timeout_in_ms());
332 
333   return Status::OK();
334 }
335 
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)336 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
337                                     GrpcChannelSpec* channel_spec) {
338   for (const auto& job : options.cluster_def->job()) {
339     std::map<int, string> host_ports;
340     for (const auto& task : job.tasks()) {
341       string& host_port = host_ports[task.first];
342       if (!host_port.empty()) {
343         return errors::InvalidArgument("JobDef for job \"", job.name(),
344                                        "\" specified two addresses for task \"",
345                                        task.first, "\": ", host_port, " and ",
346                                        task.second);
347       }
348       if (job.name() == *options.job_name && task.first == options.task_index) {
349         host_port = strings::StrCat(host_name_, ":", bound_port_);
350       } else {
351         host_port = task.second;
352       }
353     }
354     TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
355   }
356   return Status::OK();
357 }
358 
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)359 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
360                                       WorkerCacheInterface** worker_cache) {
361   if (options.job_name == nullptr || options.job_name->empty()) {
362     Status s = errors::InvalidArgument(
363         "The master (current machine) is not included in the provided "
364         "cluster_def. ",
365         options.cluster_def->DebugString());
366     LOG(WARNING) << s;
367     return s;
368   }
369 
370   GrpcChannelSpec channel_spec;
371   TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
372 
373   std::shared_ptr<GrpcChannelCache> channel_cache(
374       NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
375 
376   string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
377                                        "/task:", options.task_index);
378 
379   const string host_port = channel_cache->TranslateTask(name_prefix);
380   int requested_port;
381 
382   auto colon_index = host_port.find_last_of(':');
383   if (!strings::safe_strto32(host_port.substr(colon_index + 1),
384                              &requested_port)) {
385     return errors::Internal("Could not parse port for local server from \"",
386                             host_port, "\".");
387   }
388   if (requested_port != bound_port_) {
389     return errors::InvalidArgument("Requested port ", requested_port,
390                                    " differs from expected port ", bound_port_);
391   }
392   *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
393       channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
394   return Status::OK();
395 }
396 
Start()397 Status GrpcServer::Start() {
398   mutex_lock l(mu_);
399   switch (state_) {
400     case NEW: {
401       master_thread_.reset(
402           env_->StartThread(ThreadOptions(), "TF_master_service",
403                             [this] { master_service_->HandleRPCsLoop(); }));
404       worker_thread_.reset(
405           env_->StartThread(ThreadOptions(), "TF_worker_service",
406                             [this] { worker_service_->HandleRPCsLoop(); }));
407       eager_thread_.reset(
408           env_->StartThread(ThreadOptions(), "TF_eager_service",
409                             [this] { eager_service_->HandleRPCsLoop(); }));
410       state_ = STARTED;
411       LOG(INFO) << "Started server with target: " << target();
412       return Status::OK();
413     }
414     case STARTED:
415       LOG(INFO) << "Server already started (target: " << target() << ")";
416       return Status::OK();
417     case STOPPED:
418       return errors::FailedPrecondition("Server has stopped.");
419     default:
420       LOG(FATAL);
421   }
422 }
423 
AddMasterEagerContextToEagerService(const tensorflow::uint64 context_id,tensorflow::EagerContext * context)424 Status GrpcServer::AddMasterEagerContextToEagerService(
425     const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
426   auto* eager_service =
427       static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
428   return eager_service->CreateMasterContext(context_id, context);
429 }
430 
UpdateServerDef(const ServerDef & server_def)431 Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
432   mutex_lock l(mu_);
433   server_def_ = server_def;
434   WorkerCacheInterface* worker_cache;
435   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
436   TF_RETURN_IF_ERROR(
437       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
438   if (worker_cache == nullptr) {
439     return errors::InvalidArgument(
440         "Failed to build worker cache with the provided server def.");
441   }
442   // Transfer ownership of worker_cache to worker_env_.session_mgr.
443   worker_env_.session_mgr->ResetDefaultWorkerCache(worker_cache);
444 
445   string default_worker_name;
446   string unused;
447   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
448                                         &default_worker_name, &unused)) {
449     return errors::Internal("Could not parse worker name.");
450   }
451   std::unique_ptr<DeviceResolverDistributed> dev_resolver(
452       new DeviceResolverDistributed(worker_env_.device_mgr));
453   std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
454       new CollectiveParamResolverDistributed(
455           server_def_.default_session_config(), worker_env_.device_mgr,
456           dev_resolver.get(), worker_cache, default_worker_name));
457   worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
458       server_def_.default_session_config(), worker_env_.device_mgr,
459       std::move(dev_resolver), std::move(param_resolver),
460       MaybeCreateNcclCommunicator(), worker_cache, default_worker_name));
461 
462   master_env_.worker_cache = worker_cache;
463   master_env_.collective_executor_mgr =
464       worker_env_.collective_executor_mgr.get();
465   return Status::OK();
466 }
467 
Stop()468 Status GrpcServer::Stop() {
469   mutex_lock l(mu_);
470   switch (state_) {
471     case NEW:
472       state_ = STOPPED;
473       return Status::OK();
474     case STARTED:
475       return errors::Unimplemented(
476           "Clean shutdown is not currently implemented");
477     case STOPPED:
478       LOG(INFO) << "Server already stopped (target: " << target() << ")";
479       return Status::OK();
480     default:
481       LOG(FATAL);
482   }
483 }
484 
Join()485 Status GrpcServer::Join() {
486   mutex_lock l(mu_);
487   switch (state_) {
488     case NEW:
489       // Prevent the server from being started subsequently.
490       state_ = STOPPED;
491       return Status::OK();
492     case STARTED:
493     case STOPPED:
494       master_thread_.reset();
495       worker_thread_.reset();
496       eager_thread_.reset();
497       return Status::OK();
498     default:
499       LOG(FATAL);
500   }
501 }
502 
target() const503 const string GrpcServer::target() const {
504   return strings::StrCat("grpc://", host_name_, ":", bound_port_);
505 }
506 
GetServerCredentials(const ServerDef & server_def) const507 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
508     const ServerDef& server_def) const {
509   return ::grpc::InsecureServerCredentials();
510 }
511 
GetChannelCreationFunction() const512 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
513   // We can do this because SparseGrpcChannelCache is robust to nullptr being
514   // returned by the channel creation function
515   return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
516 }
517 
CreateMaster(MasterEnv * master_env)518 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
519   return std::unique_ptr<Master>(new Master(master_env, 0.0));
520 }
521 
522 /* static */
Create(const ServerDef & server_def,Env * env,const DeviceMgr * local_device_mgr,std::unique_ptr<ServerInterface> * out_server)523 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
524                           const DeviceMgr* local_device_mgr,
525                           std::unique_ptr<ServerInterface>* out_server) {
526   std::unique_ptr<GrpcServer> ret(
527       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
528   GrpcServerOptions options;
529   options.rendezvous_mgr_func = NewRpcRendezvousMgr;
530   options.local_device_mgr = local_device_mgr;
531   Status s = ret->Init(options);
532   if (!s.ok()) {
533     LOG(ERROR) << s;
534     return s;
535   }
536   *out_server = std::move(ret);
537   return Status::OK();
538 }
539 
540 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)541 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
542                           std::unique_ptr<ServerInterface>* out_server) {
543   return Create(server_def, env, nullptr, out_server);
544 }
545 
546 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)547 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
548                           std::unique_ptr<GrpcServer>* out_server) {
549   std::unique_ptr<ServerInterface> server;
550   Status s = Create(server_def, env, nullptr, &server);
551   if (!s.ok()) {
552     return s;
553   }
554   out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
555   return Status::OK();
556 }
557 
558 namespace {
559 
560 class GrpcServerFactory : public ServerFactory {
561  public:
AcceptsOptions(const ServerDef & server_def)562   bool AcceptsOptions(const ServerDef& server_def) override {
563     return server_def.protocol() == "grpc";
564   }
565 
NewServer(const ServerDef & server_def,const Options & options,std::unique_ptr<ServerInterface> * out_server)566   Status NewServer(const ServerDef& server_def, const Options& options,
567                    std::unique_ptr<ServerInterface>* out_server) override {
568     return GrpcServer::Create(server_def, Env::Default(),
569                               options.local_device_mgr, out_server);
570   }
571 };
572 
573 // Registers a `ServerFactory` for `GrpcServer` instances.
574 class GrpcServerRegistrar {
575  public:
GrpcServerRegistrar()576   GrpcServerRegistrar() {
577     ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
578   }
579 };
580 static GrpcServerRegistrar registrar;
581 
582 }  // namespace
583 }  // namespace tensorflow
584