• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17 
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 #include "grpcpp/server_builder.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/process_util.h"
29 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
30 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
31 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
32 #include "tensorflow/core/distributed_runtime/local_master.h"
33 #include "tensorflow/core/distributed_runtime/master.h"
34 #include "tensorflow/core/distributed_runtime/master_env.h"
35 #include "tensorflow/core/distributed_runtime/master_session.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
42 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
43 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
44 #include "tensorflow/core/distributed_runtime/server_lib.h"
45 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
46 #include "tensorflow/core/distributed_runtime/worker_env.h"
47 #include "tensorflow/core/framework/op.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/nccl/collective_communicator.h"
51 #include "tensorflow/core/platform/cpu_info.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mem.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
56 #include "tensorflow/core/public/session_options.h"
57 #include "tensorflow/core/util/env_var.h"
58 
59 namespace tensorflow {
60 
61 namespace {
62 
63 // Define an option subclass in order to disable SO_REUSEPORT for the
64 // server socket.
65 class NoReusePortOption : public ::grpc::ServerBuilderOption {
66  public:
UpdateArguments(::grpc::ChannelArguments * args)67   void UpdateArguments(::grpc::ChannelArguments* args) override {
68     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
69   }
70 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)71   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
72                          plugins) override {}
73 };
74 
75 // Define an option subclass in order to enable SO_REUSEPORT for the
76 // server socket.
77 class ReusePortOption : public ::grpc::ServerBuilderOption {
78  public:
UpdateArguments(::grpc::ChannelArguments * args)79   void UpdateArguments(::grpc::ChannelArguments* args) override {
80     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 1);
81   }
82 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)83   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
84                          plugins) override {}
85 };
86 
87 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)88 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
89   return new RpcRendezvousMgr(env);
90 }
91 
92 }  // namespace
93 
GrpcServer(const ServerDef & server_def,Env * env)94 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
95     : env_(env), state_(NEW), server_def_(server_def) {}
96 
~GrpcServer()97 GrpcServer::~GrpcServer() {
98   TF_CHECK_OK(Stop());
99   TF_CHECK_OK(Join());
100 
101   delete master_service_;
102   delete worker_service_;
103   delete eager_service_;
104 
105   for (auto& kv : extra_services_) {
106     AsyncServiceInterface* service = kv.second;
107     delete service;
108   }
109 
110   // TODO(mrry): Refactor the *Env classes so that it is less fiddly
111   // to destroy them.
112 
113   // Shut down all outstanding rendezvous.
114   delete worker_env_.rendezvous_mgr;
115 
116   // We must delete graph_mgr before device_mgr, due to shared
117   // ownership of OpKernels in the executors. (The graph_mgr will
118   // free all stateless OpKernels, and pass over borrowed stateful
119   // OpKernels, which are also held in their respective devices'
120   // OpSegments.)
121   if (worker_env_.session_mgr != nullptr) {
122     delete worker_env_.session_mgr;  // Deletes graph_mgr's.
123   }
124 
125   // Do not delete (as these are not owned by the server):
126   // - master_env_.env
127   // - worker_env_.env
128   // - worker_env_.compute_pool
129 }
130 
131 // Look up the requested host name and port for this task in `server_def`.
GetHostAndPort(const ServerDef & server_def,string * host_name,int * port) const132 Status GrpcServer::GetHostAndPort(const ServerDef& server_def,
133                                   string* host_name, int* port) const {
134   *port = -1;
135   *host_name = "localhost";
136   for (const auto& job : server_def.cluster().job()) {
137     if (job.name() == server_def.job_name()) {
138       auto iter = job.tasks().find(server_def.task_index());
139       if (iter == job.tasks().end()) {
140         return errors::Internal("Task ", server_def.task_index(),
141                                 " was not defined in job \"",
142                                 server_def.job_name(), "\"");
143       }
144 
145       if (server_def.port() != 0) {
146         *port = server_def.port();
147       } else {
148         auto colon_index = iter->second.find_last_of(':');
149         if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
150                                    port)) {
151           return errors::InvalidArgument(
152               "Could not parse port for local server from \"", iter->second,
153               "\".");
154         }
155 
156         if (colon_index != string::npos &&
157             !iter->second.substr(0, colon_index).empty()) {
158           *host_name = iter->second.substr(0, colon_index);
159         }
160       }
161       break;
162     }
163   }
164   if (*port == -1) {
165     return errors::Internal("Job \"", server_def.job_name(),
166                             "\" was not defined in cluster");
167   }
168 
169   return Status::OK();
170 }
171 
Init(const GrpcServerOptions & opts)172 Status GrpcServer::Init(const GrpcServerOptions& opts) {
173   mutex_lock l(mu_);
174   CHECK_EQ(state_, NEW);
175   master_env_.env = env_;
176   worker_env_.env = env_;
177 
178   // Check parameters before DeviceFactory::AddDevices,
179   // otherwise if 'task_index=-1' the program will abort.
180 
181   int requested_port;
182   TF_RETURN_IF_ERROR(GetHostAndPort(server_def_, &host_name_, &requested_port));
183 
184   SessionOptions sess_opts;
185   VLOG(3) << "Grpc Server Init Definition: " << server_def_.DebugString();
186   ConfigProto config = server_def_.default_session_config();
187   sess_opts.config = config;
188 
189   // Configure shared devices between master and worker.
190   string name_prefix =
191       strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
192                       "/task:", server_def_.task_index());
193   if (opts.local_device_mgr == nullptr) {
194     std::vector<std::unique_ptr<Device>> devices;
195     TF_RETURN_IF_ERROR(
196         DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
197     worker_env_.device_mgr = new DynamicDeviceMgr(std::move(devices));
198     owned_device_manager_.reset(worker_env_.device_mgr);
199   } else {
200     worker_env_.device_mgr = opts.local_device_mgr;
201     owned_device_manager_.reset(nullptr);
202   }
203   worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
204   master_env_.local_devices = worker_env_.device_mgr->ListDevices();
205   worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
206                                    ? new RpcRendezvousMgr(&worker_env_)
207                                    : opts.rendezvous_mgr_func(&worker_env_);
208   string unused;
209   string default_worker_name;
210   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
211                                         &default_worker_name, &unused)) {
212     return errors::Internal("Could not parse worker name.");
213   }
214 
215   // N.B. The order of initialization here is intricate, because we
216   // wish to allow `requested_port == 0` (for choosing any port,
217   // mostly for testing). Therefore, the construction of the channel
218   // and worker caches depends on `bound_port_`, which is not set
219   // until we call `builder.BuildAndStart()`. We must create the
220   // service objects before calling `builder.BuildAndStart()`, but
221   // `master_env_` and `worker_env_` are only partially
222   // configured. However, this is not dangerous, because we do not
223   // start serving requests until `this->Start()` is called, which
224   // happens after this method returns.
225   //
226   // TODO(mrry): Provide a general mechanism for dynamically setting
227   // the identities of tasks in the worker pool after the service is
228   // running.
229   ::grpc::ServerBuilder builder;
230   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
231                            GetServerCredentials(server_def_), &bound_port_);
232   builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
233 
234   bool reuse_port = false;
235   const Status status =
236       ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port);
237   if (!status.ok()) {
238     LOG(ERROR) << status.error_message();
239   }
240   auto server_build_option =
241       reuse_port
242           ? std::unique_ptr<::grpc::ServerBuilderOption>(new ReusePortOption)
243           : std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption);
244   builder.SetOption(std::move(server_build_option));
245 
246   // Allow subclasses to specify more args to pass to the gRPC server.
247   MaybeMutateBuilder(&builder, requested_port);
248   master_impl_ = CreateMaster(&master_env_);
249   master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
250   worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
251                                   : NewGrpcWorker(&worker_env_, config);
252   worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
253                                          opts.worker_service_options)
254                         .release();
255   eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
256 
257   profiler_service_ = profiler::CreateProfilerService();
258   builder.RegisterService(profiler_service_.get());
259 
260   // Add any extra services to be started.
261   extra_services_ = ExtraServices(&builder);
262 
263   // extra service:
264   if (opts.service_func != nullptr) {
265     opts.service_func(&worker_env_, &builder);
266   }
267   server_ = builder.BuildAndStart();
268 
269   if (!server_) {
270     return errors::Unknown("Could not start gRPC server");
271   }
272   // Create the execution environment for the GRPC workers cache.
273   grpc_worker_env_.reset(CreateGrpcWorkerEnv());
274 
275   WorkerCacheInterface* worker_cache;
276   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
277   TF_RETURN_IF_ERROR(
278       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
279   CHECK_NE(nullptr, worker_cache);
280 
281   if (opts.collective_mgr_func) {
282     worker_env_.collective_executor_mgr.reset(
283         opts.collective_mgr_func(config, &worker_env_, worker_cache));
284     if (worker_env_.collective_executor_mgr == nullptr) {
285       return errors::Internal(
286           "collective_mgr_func did not return CollectiveExecutorMgr");
287     }
288   } else {
289     worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
290         config, worker_env_.device_mgr, MaybeCreateNcclCommunicator(config),
291         worker_cache, default_worker_name);
292   }
293 
294   // Set up worker environment.
295   worker_env_.session_mgr = new SessionMgr(
296       &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
297       std::unique_ptr<WorkerCacheInterface>(worker_cache),
298       [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
299         WorkerCacheFactoryOptions options(server_def);
300         return WorkerCacheFactory(options, worker_cache);
301       });
302   worker_env_.compute_pool = ComputePool(sess_opts);
303 
304   // Finish setting up master environment.
305   master_env_.ops = OpRegistry::Global();
306   master_env_.worker_cache = worker_cache;
307   master_env_.collective_executor_mgr =
308       worker_env_.collective_executor_mgr.get();
309   StatsPublisherFactory stats_factory = opts.stats_factory;
310   master_env_.master_session_factory =
311       [config, stats_factory](
312           SessionOptions options, const MasterEnv* env,
313           std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
314           std::unique_ptr<WorkerCacheInterface> worker_cache,
315           std::unique_ptr<DeviceSet> device_set,
316           std::vector<string> filtered_worker_list) {
317         options.config.MergeFrom(config);
318         return new MasterSession(options, env, std::move(remote_devs),
319                                  std::move(worker_cache), std::move(device_set),
320                                  std::move(filtered_worker_list),
321                                  stats_factory);
322       };
323   master_env_.worker_cache_factory =
324       [this](const WorkerCacheFactoryOptions& options,
325              WorkerCacheInterface** worker_cache) {
326         return WorkerCacheFactory(options, worker_cache);
327       };
328 
329   // Provide direct access to the master from in-process clients.
330   LocalMaster::Register(target(), master_impl_.get(),
331                         config.operation_timeout_in_ms());
332 
333   return Status::OK();
334 }
335 
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)336 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
337                                     GrpcChannelSpec* channel_spec) {
338   for (const auto& job : options.cluster_def->job()) {
339     std::map<int, string> host_ports;
340     for (const auto& task : job.tasks()) {
341       string& host_port = host_ports[task.first];
342       if (!host_port.empty()) {
343         return errors::InvalidArgument("JobDef for job \"", job.name(),
344                                        "\" specified two addresses for task \"",
345                                        task.first, "\": ", host_port, " and ",
346                                        task.second);
347       }
348       if (job.name() == *options.job_name && task.first == options.task_index) {
349         host_port = strings::StrCat(host_name_, ":", bound_port_);
350       } else {
351         host_port = task.second;
352       }
353     }
354     TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
355   }
356   return Status::OK();
357 }
358 
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)359 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
360                                       WorkerCacheInterface** worker_cache) {
361   if (options.job_name == nullptr || options.job_name->empty()) {
362     Status s = errors::InvalidArgument(
363         "The master (current machine) is not included in the provided "
364         "cluster_def. ",
365         options.cluster_def->DebugString());
366     LOG(WARNING) << s;
367     return s;
368   }
369 
370   GrpcChannelSpec channel_spec;
371   TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
372 
373   if (options.rpc_options == nullptr) {
374     return errors::InvalidArgument(
375         "rpc_options not set in WorkerCacheFactoryOptions");
376   }
377   std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
378       channel_spec, GetChannelCreationFunction(), *options.rpc_options));
379 
380   string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
381                                        "/task:", options.task_index);
382 
383   const string host_port = channel_cache->TranslateTask(name_prefix);
384   int requested_port;
385 
386   auto colon_index = host_port.find_last_of(':');
387   if (!strings::safe_strto32(host_port.substr(colon_index + 1),
388                              &requested_port)) {
389     return errors::Internal("Could not parse port for local server from \"",
390                             host_port, "\".");
391   }
392   if (requested_port != bound_port_) {
393     return errors::InvalidArgument("Requested port ", requested_port,
394                                    " differs from expected port ", bound_port_);
395   }
396   *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
397       channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
398   return Status::OK();
399 }
400 
Start()401 Status GrpcServer::Start() {
402   mutex_lock l(mu_);
403   switch (state_) {
404     case NEW: {
405       master_thread_.reset(
406           env_->StartThread(ThreadOptions(), "TF_master_service",
407                             [this] { master_service_->HandleRPCsLoop(); }));
408       worker_thread_.reset(
409           env_->StartThread(ThreadOptions(), "TF_worker_service",
410                             [this] { worker_service_->HandleRPCsLoop(); }));
411       eager_thread_.reset(
412           env_->StartThread(ThreadOptions(), "TF_eager_service",
413                             [this] { eager_service_->HandleRPCsLoop(); }));
414 
415       for (const auto& kv : extra_services_) {
416         const std::string& service_name = kv.first;
417         AsyncServiceInterface* service = kv.second;
418         std::unique_ptr<Thread> extra_service_thread;
419         extra_service_thread.reset(env_->StartThread(
420             ThreadOptions(), service_name,
421             [service = service] { service->HandleRPCsLoop(); }));
422         extra_service_threads_.push_back(std::move(extra_service_thread));
423         VLOG(3) << "Started extra service: " << service_name;
424       }
425 
426       state_ = STARTED;
427       LOG(INFO) << "Started server with target: " << target();
428       return Status::OK();
429     }
430     case STARTED:
431       LOG(INFO) << "Server already started (target: " << target() << ")";
432       return Status::OK();
433     case STOPPED:
434       return errors::FailedPrecondition("Server has stopped.");
435     default:
436       LOG(FATAL);
437   }
438 }
439 
AddMasterEagerContextToEagerService(const tensorflow::uint64 context_id,tensorflow::EagerContext * context)440 Status GrpcServer::AddMasterEagerContextToEagerService(
441     const tensorflow::uint64 context_id, tensorflow::EagerContext* context) {
442   auto* eager_service =
443       static_cast<eager::GrpcEagerServiceImpl*>(eager_service_);
444   return eager_service->CreateMasterContext(context_id, context);
445 }
446 
UpdateServerDef(const ServerDef & server_def)447 Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
448   mutex_lock l(mu_);
449   server_def_ = server_def;
450   WorkerCacheInterface* worker_cache;
451   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
452   TF_RETURN_IF_ERROR(
453       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
454   if (worker_cache == nullptr) {
455     return errors::InvalidArgument(
456         "Failed to build worker cache with the provided server def.");
457   }
458   // Transfer ownership of worker_cache to worker_env_.session_mgr.
459   worker_env_.session_mgr->ResetDefaultWorkerCache(worker_cache);
460 
461   string default_worker_name;
462   string unused;
463   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
464                                         &default_worker_name, &unused)) {
465     return errors::Internal("Could not parse worker name.");
466   }
467   worker_env_.collective_executor_mgr = CreateProdRpcCollectiveExecutorMgr(
468       server_def_.default_session_config(), worker_env_.device_mgr,
469       MaybeCreateNcclCommunicator(server_def_.default_session_config()),
470       worker_cache, default_worker_name);
471 
472   master_env_.worker_cache = worker_cache;
473   master_env_.collective_executor_mgr =
474       worker_env_.collective_executor_mgr.get();
475   return Status::OK();
476 }
477 
478 // TODO(haoyuzhang): Remove this method once we have a mechanism to directly set
479 // field inside the RPC coordination service handler.
SetCoordinationServiceAgentInstance(CoordinationServiceAgent * agent)480 Status GrpcServer::SetCoordinationServiceAgentInstance(
481     CoordinationServiceAgent* agent) {
482   // No op, coordination service is not implemented in open source.
483   return Status::OK();
484 }
485 
Stop()486 Status GrpcServer::Stop() {
487   mutex_lock l(mu_);
488   switch (state_) {
489     case NEW:
490       state_ = STOPPED;
491       return Status::OK();
492     case STARTED:
493       return errors::Unimplemented(
494           "Clean shutdown is not currently implemented");
495     case STOPPED:
496       LOG(INFO) << "Server already stopped (target: " << target() << ")";
497       return Status::OK();
498     default:
499       LOG(FATAL);
500   }
501 }
502 
Join()503 Status GrpcServer::Join() {
504   mutex_lock l(mu_);
505   switch (state_) {
506     case NEW:
507       // Prevent the server from being started subsequently.
508       state_ = STOPPED;
509       return Status::OK();
510     case STARTED:
511     case STOPPED:
512       master_thread_.reset();
513       worker_thread_.reset();
514       eager_thread_.reset();
515       for (auto& thread : extra_service_threads_) {
516         thread.reset();
517       }
518       return Status::OK();
519     default:
520       LOG(FATAL);
521   }
522 }
523 
target() const524 const string GrpcServer::target() const {
525   return strings::StrCat("grpc://", host_name_, ":", bound_port_);
526 }
527 
GetServerCredentials(const ServerDef & server_def) const528 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
529     const ServerDef& server_def) const {
530   return ::grpc::InsecureServerCredentials();
531 }
532 
GetChannelCreationFunction() const533 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
534   // We can do this because SparseGrpcChannelCache is robust to nullptr being
535   // returned by the channel creation function
536   return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
537 }
538 
CreateMaster(MasterEnv * master_env)539 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
540   return std::unique_ptr<Master>(new Master(master_env, 0.0));
541 }
542 
543 /* static */
Create(const ServerDef & server_def,Env * env,DeviceMgr * local_device_mgr,std::unique_ptr<ServerInterface> * out_server)544 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
545                           DeviceMgr* local_device_mgr,
546                           std::unique_ptr<ServerInterface>* out_server) {
547   std::unique_ptr<GrpcServer> ret(
548       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
549   GrpcServerOptions options;
550   options.rendezvous_mgr_func = NewRpcRendezvousMgr;
551   options.local_device_mgr = local_device_mgr;
552   Status s = ret->Init(options);
553   if (!s.ok()) {
554     LOG(ERROR) << s;
555     return s;
556   }
557   *out_server = std::move(ret);
558   return Status::OK();
559 }
560 
561 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)562 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
563                           std::unique_ptr<ServerInterface>* out_server) {
564   return Create(server_def, env, nullptr, out_server);
565 }
566 
567 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)568 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
569                           std::unique_ptr<GrpcServer>* out_server) {
570   std::unique_ptr<ServerInterface> server;
571   Status s = Create(server_def, env, nullptr, &server);
572   if (!s.ok()) {
573     return s;
574   }
575   out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
576   return Status::OK();
577 }
578 
579 namespace {
580 
581 class GrpcServerFactory : public ServerFactory {
582  public:
AcceptsOptions(const ServerDef & server_def)583   bool AcceptsOptions(const ServerDef& server_def) override {
584     return server_def.protocol() == "grpc";
585   }
586 
NewServer(const ServerDef & server_def,const Options & options,std::unique_ptr<ServerInterface> * out_server)587   Status NewServer(const ServerDef& server_def, const Options& options,
588                    std::unique_ptr<ServerInterface>* out_server) override {
589     return GrpcServer::Create(server_def, Env::Default(),
590                               options.local_device_mgr, out_server);
591   }
592 };
593 
594 // Registers a `ServerFactory` for `GrpcServer` instances.
595 class GrpcServerRegistrar {
596  public:
GrpcServerRegistrar()597   GrpcServerRegistrar() {
598     ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
599   }
600 };
601 static GrpcServerRegistrar registrar;
602 
603 }  // namespace
604 }  // namespace tensorflow
605