• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17 
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 
22 #include "grpc++/grpc++.h"
23 #include "grpc++/security/credentials.h"
24 #include "grpc++/server_builder.h"
25 #include "grpc/support/alloc.h"
26 
27 #include "tensorflow/core/common_runtime/device_factory.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/process_util.h"
30 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
31 #include "tensorflow/core/distributed_runtime/local_master.h"
32 #include "tensorflow/core/distributed_runtime/master.h"
33 #include "tensorflow/core/distributed_runtime/master_env.h"
34 #include "tensorflow/core/distributed_runtime/master_session.h"
35 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
36 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
40 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
41 #include "tensorflow/core/distributed_runtime/server_lib.h"
42 #include "tensorflow/core/distributed_runtime/worker_env.h"
43 #include "tensorflow/core/framework/op.h"
44 #include "tensorflow/core/lib/strings/strcat.h"
45 #include "tensorflow/core/platform/env.h"
46 #include "tensorflow/core/platform/mem.h"
47 #include "tensorflow/core/public/session_options.h"
48 
49 namespace tensorflow {
50 
51 namespace {
52 
53 // Define an option subclass in order to disable SO_REUSEPORT for the
54 // server socket.
55 class NoReusePortOption : public ::grpc::ServerBuilderOption {
56  public:
UpdateArguments(::grpc::ChannelArguments * args)57   void UpdateArguments(::grpc::ChannelArguments* args) override {
58     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
59   }
60 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)61   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
62                          plugins) override {}
63 };
64 
65 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)66 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
67   return new RpcRendezvousMgr(env);
68 }
69 
70 }  // namespace
71 
GrpcServer(const ServerDef & server_def,Env * env)72 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
73     : server_def_(server_def), env_(env), state_(NEW) {}
74 
~GrpcServer()75 GrpcServer::~GrpcServer() {
76   TF_CHECK_OK(Stop());
77   TF_CHECK_OK(Join());
78 
79   delete master_service_;
80   delete worker_service_;
81 
82   // TODO(mrry): Refactor the *Env classes so that it is less fiddly
83   // to destroy them.
84 
85   // Shut down all outstanding rendezvous.
86   delete worker_env_.rendezvous_mgr;
87 
88   // We must delete graph_mgr before device_mgr, due to shared
89   // ownership of OpKernels in the executors. (The graph_mgr will
90   // free all stateless OpKernels, and pass over borrowed stateful
91   // OpKernels, which are also held in their respective devices'
92   // OpSegments.)
93   if (worker_env_.session_mgr != nullptr) {
94     delete worker_env_.session_mgr;  // Deletes graph_mgr's.
95   } else {
96     // Note: session_mgr's legacy_session_ deletes device_mgr now.
97     delete worker_env_.device_mgr;
98   }
99 
100   // Do not delete (as these are not owned by the server):
101   // - master_env_.env
102   // - worker_env_.env
103   // - worker_env_.compute_pool
104 }
105 
Init(ServiceInitFunction service_func,const RendezvousMgrCreationFunction & rendezvous_mgr_func,const WorkerCreationFunction & worker_func)106 Status GrpcServer::Init(
107     ServiceInitFunction service_func,
108     const RendezvousMgrCreationFunction& rendezvous_mgr_func,
109     const WorkerCreationFunction& worker_func) {
110   mutex_lock l(mu_);
111   CHECK_EQ(state_, NEW);
112   master_env_.env = env_;
113   worker_env_.env = env_;
114 
115   SessionOptions sess_opts;
116   ConfigProto config = server_def_.default_session_config();
117   sess_opts.config = config;
118 
119   // Configure shared devices between master and worker.
120   string name_prefix =
121       strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
122                       "/task:", server_def_.task_index());
123   TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
124                                                &master_env_.local_devices));
125   worker_env_.local_devices = master_env_.local_devices;
126   worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
127   worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
128                                    ? new RpcRendezvousMgr(&worker_env_)
129                                    : rendezvous_mgr_func(&worker_env_);
130   string unused;
131   string default_worker_name;
132   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
133                                         &default_worker_name, &unused)) {
134     return errors::Internal("Could not parse worker name.");
135   }
136 
137   // Look up the port that has been requested for this task in `server_def_`.
138   int requested_port = -1;
139   for (const auto& job : server_def_.cluster().job()) {
140     if (job.name() == server_def_.job_name()) {
141       auto iter = job.tasks().find(server_def_.task_index());
142       if (iter == job.tasks().end()) {
143         return errors::InvalidArgument("Task ", server_def_.task_index(),
144                                        " was not defined in job \"",
145                                        server_def_.job_name(), "\"");
146       }
147       const std::vector<string> hostname_port =
148           str_util::Split(iter->second, ':');
149       if (hostname_port.size() != 2 ||
150           !strings::safe_strto32(hostname_port[1], &requested_port)) {
151         return errors::InvalidArgument(
152             "Could not parse port for local server from \"", iter->second,
153             "\"");
154       } else {
155         break;
156       }
157     }
158   }
159   if (requested_port == -1) {
160     return errors::Internal("Job \"", server_def_.job_name(),
161                             "\" was not defined in cluster");
162   }
163 
164   // N.B. The order of initialization here is intricate, because we
165   // wish to allow `requested_port == 0` (for choosing any port,
166   // mostly for testing). Therefore, the construction of the channel
167   // and worker caches depends on `bound_port_`, which is not set
168   // until we call `builder.BuildAndStart()`. We must create the
169   // service objects before calling `builder.BuildAndStart()`, but
170   // `master_env_` and `worker_env_` are only partially
171   // configured. However, this is not dangerous, because we do not
172   // start serving requests until `this->Start()` is called, which
173   // happens after this method returns.
174   //
175   // TODO(mrry): Provide a general mechanism for dynamically setting
176   // the identities of tasks in the worker pool after the service is
177   // running.
178   ::grpc::ServerBuilder builder;
179   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
180                            GetServerCredentials(server_def_), &bound_port_);
181   builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
182   builder.SetOption(
183       std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
184   master_impl_ = CreateMaster(&master_env_);
185   master_service_ = NewGrpcMasterService(
186       master_impl_.get(), config.operation_timeout_in_ms(), &builder);
187   worker_impl_ =
188       worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
189   worker_service_ =
190       NewGrpcWorkerService(worker_impl_.get(), &builder).release();
191   // extra service:
192   if (service_func != nullptr) {
193     service_func(&worker_env_, &builder);
194   }
195   server_ = builder.BuildAndStart();
196 
197   if (!server_) {
198     return errors::Unknown("Could not start gRPC server");
199   }
200 
201   WorkerCacheInterface* worker_cache;
202   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
203   TF_RETURN_IF_ERROR(
204       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
205   CHECK_NE(nullptr, worker_cache);
206 
207   // Set up worker environment.
208   worker_env_.session_mgr = new SessionMgr(
209       &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
210       std::unique_ptr<WorkerCacheInterface>(worker_cache),
211       [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
212         WorkerCacheFactoryOptions options(server_def);
213         return WorkerCacheFactory(options, worker_cache);
214       });
215   worker_env_.compute_pool = ComputePool(sess_opts);
216 
217   // Finish setting up master environment.
218   master_env_.ops = OpRegistry::Global();
219   master_env_.worker_cache = worker_cache;
220   master_env_.master_session_factory =
221       [config](
222           SessionOptions options, const MasterEnv* env,
223           std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
224           std::unique_ptr<WorkerCacheInterface> worker_cache,
225           std::unique_ptr<DeviceSet> device_set) {
226         options.config.MergeFrom(config);
227         return new MasterSession(options, env, std::move(remote_devs),
228                                  std::move(worker_cache), std::move(device_set),
229                                  CreateNoOpStatsPublisher);
230       };
231   master_env_.worker_cache_factory =
232       [this](const WorkerCacheFactoryOptions& options,
233              WorkerCacheInterface** worker_cache) {
234         return WorkerCacheFactory(options, worker_cache);
235       };
236 
237   // Provide direct access to the master from in-process clients.
238   LocalMaster::Register(target(), master_impl_.get(),
239                         config.operation_timeout_in_ms());
240 
241   return Status::OK();
242 }
243 
Init(ServiceInitFunction service_func,const RendezvousMgrCreationFunction & rendezvous_mgr_func)244 Status GrpcServer::Init(
245     ServiceInitFunction service_func,
246     const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
247   return Init(service_func, rendezvous_mgr_func, nullptr);
248 }
249 
Init()250 Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }
251 
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)252 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
253                                     GrpcChannelSpec* channel_spec) {
254   for (const auto& job : options.cluster_def->job()) {
255     std::map<int, string> host_ports;
256     for (const auto& task : job.tasks()) {
257       string& host_port = host_ports[task.first];
258       if (!host_port.empty()) {
259         return errors::InvalidArgument("JobDef for job \"", job.name(),
260                                        "\" specified two addresses for task \"",
261                                        task.first, "\": ", host_port, " and ",
262                                        task.second);
263       }
264       if (job.name() == *options.job_name && task.first == options.task_index) {
265         host_port = strings::StrCat("localhost:", bound_port_);
266       } else {
267         host_port = task.second;
268       }
269     }
270     TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
271   }
272   return Status::OK();
273 }
274 
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)275 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
276                                       WorkerCacheInterface** worker_cache) {
277   if (options.job_name == nullptr || options.job_name->empty()) {
278     Status s = errors::InvalidArgument(
279         "The master (current machine) is not included in the provided "
280         "cluster_def. ",
281         options.cluster_def->DebugString());
282     LOG(WARNING) << s;
283     return s;
284   }
285 
286   GrpcChannelSpec channel_spec;
287   TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
288 
289   std::unique_ptr<GrpcChannelCache> channel_cache(
290       NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
291 
292   string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
293                                        "/task:", options.task_index);
294 
295   const string host_port = channel_cache->TranslateTask(name_prefix);
296   int requested_port;
297 
298   if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
299                              &requested_port)) {
300     return errors::Internal("Could not parse port for local server from \"",
301                             channel_cache->TranslateTask(name_prefix), "\".");
302   }
303   if (requested_port != bound_port_) {
304     return errors::InvalidArgument("Requested port ", requested_port,
305                                    " differs from expected port ", bound_port_);
306   }
307 
308   *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
309       channel_cache.release(), worker_impl_.get(), name_prefix);
310   return Status::OK();
311 }
312 
Start()313 Status GrpcServer::Start() {
314   mutex_lock l(mu_);
315   switch (state_) {
316     case NEW: {
317       master_thread_.reset(
318           env_->StartThread(ThreadOptions(), "TF_master_service",
319                             [this] { master_service_->HandleRPCsLoop(); }));
320       worker_thread_.reset(
321           env_->StartThread(ThreadOptions(), "TF_worker_service",
322                             [this] { worker_service_->HandleRPCsLoop(); }));
323       state_ = STARTED;
324       LOG(INFO) << "Started server with target: " << target();
325       return Status::OK();
326     }
327     case STARTED:
328       LOG(INFO) << "Server already started (target: " << target() << ")";
329       return Status::OK();
330     case STOPPED:
331       return errors::FailedPrecondition("Server has stopped.");
332     default:
333       LOG(FATAL);
334   }
335 }
336 
Stop()337 Status GrpcServer::Stop() {
338   mutex_lock l(mu_);
339   switch (state_) {
340     case NEW:
341       state_ = STOPPED;
342       return Status::OK();
343     case STARTED:
344       return errors::Unimplemented(
345           "Clean shutdown is not currently implemented");
346     case STOPPED:
347       LOG(INFO) << "Server already stopped (target: " << target() << ")";
348       return Status::OK();
349     default:
350       LOG(FATAL);
351   }
352 }
353 
Join()354 Status GrpcServer::Join() {
355   mutex_lock l(mu_);
356   switch (state_) {
357     case NEW:
358       // Prevent the server from being started subsequently.
359       state_ = STOPPED;
360       return Status::OK();
361     case STARTED:
362     case STOPPED:
363       master_thread_.reset();
364       worker_thread_.reset();
365       return Status::OK();
366     default:
367       LOG(FATAL);
368   }
369 }
370 
target() const371 const string GrpcServer::target() const {
372   return strings::StrCat("grpc://localhost:", bound_port_);
373 }
374 
GetServerCredentials(const ServerDef & server_def) const375 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
376     const ServerDef& server_def) const {
377   return ::grpc::InsecureServerCredentials();
378 }
379 
GetChannelCreationFunction() const380 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
381   // We can do this because SparseGrpcChannelCache is robust to nullptr being
382   // returned by the channel creation function
383   return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
384 }
385 
CreateMaster(MasterEnv * master_env)386 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
387   return std::unique_ptr<Master>(new Master(master_env, 0.0));
388 }
389 
390 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)391 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
392                           std::unique_ptr<ServerInterface>* out_server) {
393   std::unique_ptr<GrpcServer> ret(
394       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
395   ServiceInitFunction service_func = nullptr;
396   TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
397   *out_server = std::move(ret);
398   return Status::OK();
399 }
400 
401 namespace {
402 
403 class GrpcServerFactory : public ServerFactory {
404  public:
AcceptsOptions(const ServerDef & server_def)405   bool AcceptsOptions(const ServerDef& server_def) override {
406     return server_def.protocol() == "grpc";
407   }
408 
NewServer(const ServerDef & server_def,std::unique_ptr<ServerInterface> * out_server)409   Status NewServer(const ServerDef& server_def,
410                    std::unique_ptr<ServerInterface>* out_server) override {
411     return GrpcServer::Create(server_def, Env::Default(), out_server);
412   }
413 };
414 
415 // Registers a `ServerFactory` for `GrpcServer` instances.
416 class GrpcServerRegistrar {
417  public:
GrpcServerRegistrar()418   GrpcServerRegistrar() {
419     gpr_allocation_functions alloc_fns;
420     memset(&alloc_fns, 0, sizeof(alloc_fns));
421     alloc_fns.malloc_fn = port::Malloc;
422     alloc_fns.realloc_fn = port::Realloc;
423     alloc_fns.free_fn = port::Free;
424     gpr_set_allocation_functions(alloc_fns);
425     ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
426   }
427 };
428 static GrpcServerRegistrar registrar;
429 
430 }  // namespace
431 }  // namespace tensorflow
432