• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifdef TENSORFLOW_USE_VERBS
17 
18 #include "tensorflow/contrib/verbs/verbs_server_lib.h"
19 
20 #include "grpc/support/alloc.h"
21 
22 #include "tensorflow/contrib/verbs/rdma_mgr.h"
23 #include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
24 #include "tensorflow/core/distributed_runtime/server_lib.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/env.h"
27 
28 namespace tensorflow {
29 
30 namespace {
31 // static utility function
NewRdmaRendezvousMgr(const WorkerEnv * env)32 RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) {
33   return new RdmaRendezvousMgr(env);
34 }
35 
36 }  // namespace
37 
VerbsServer(const ServerDef & server_def,Env * env)38 VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
39     : GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
40 
~VerbsServer()41 VerbsServer::~VerbsServer() {
42   TF_CHECK_OK(Stop());
43   TF_CHECK_OK(Join());
44   delete rdma_mgr_;
45   delete verbs_service_;
46   delete channel_cache_;
47 }
48 
ChannelCacheFactory(const ServerDef & server_def,GrpcChannelCache ** channel_cache)49 Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
50                                         GrpcChannelCache** channel_cache) {
51   string name_prefix =
52       strings::StrCat("/job:", server_def.job_name(), "/replica:0",
53                       "/task:", server_def.task_index());
54 
55   GrpcChannelSpec channel_spec;
56   TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
57 
58   *channel_cache =
59       NewGrpcChannelCache(channel_spec, GetChannelCreationFunction());
60 
61   const string host_port = (*channel_cache)->TranslateTask(name_prefix);
62   int requested_port;
63 
64   if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
65                              &requested_port)) {
66     return errors::Internal("Could not parse port for local server from \"",
67                             (*channel_cache)->TranslateTask(name_prefix),
68                             "\".");
69   }
70   if (requested_port != bound_port()) {
71     return errors::InvalidArgument("Requested port ", requested_port,
72                                    " differs from expected port ",
73                                    bound_port());
74   }
75 
76   return Status::OK();
77 }
78 
Init(ServiceInitFunction service_func,RendezvousMgrCreationFunction rendezvous_mgr_func)79 Status VerbsServer::Init(ServiceInitFunction service_func,
80                          RendezvousMgrCreationFunction rendezvous_mgr_func) {
81   Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
82   {
83     mutex_lock l(mu_);
84     CHECK_EQ(verbs_state_, DISCONNECTED);
85     CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
86     rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
87     // set rdma_mgr for verbs_service and rdma_rendezvous_mgr
88     verbs_service_->SetRdmaMgr(rdma_mgr_);
89     dynamic_cast<RdmaRendezvousMgr*>(worker_env()->rendezvous_mgr)
90         ->SetRdmaMgr(rdma_mgr_);
91   }
92   return s;
93 }
94 
Start()95 Status VerbsServer::Start() {
96   Status s = GrpcServer::Start();
97   {
98     mutex_lock l(mu_);
99     if (verbs_state_ == DISCONNECTED) {
100       // verbs_thread needs to be initiated
101       // before rdma_mgr sets up the rdma channels.
102       verbs_thread_.reset(worker_env()->env->StartThread(
103           ThreadOptions(), "TF_verbs_service",
104           [this] { verbs_service_->HandleRPCsLoop(); }));
105       rdma_mgr_->SetupChannels();
106       CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
107       rdma_mgr_->InitAllocators();
108       verbs_state_ = CONNECTED;
109     }
110   }
111   return s;
112 }
113 
Join()114 Status VerbsServer::Join() {
115   Status s = GrpcServer::Join();
116   {
117     mutex_lock l(mu_);
118     if (verbs_state_ == CONNECTED) {
119       verbs_state_ = DISCONNECTED;
120       verbs_thread_.reset();
121     }
122   }
123   return s;
124 }
125 
126 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)127 Status VerbsServer::Create(const ServerDef& server_def, Env* env,
128                            std::unique_ptr<ServerInterface>* out_server) {
129   std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
130   ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
131                                             ::grpc::ServerBuilder* builder) {
132     return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
133   };
134   TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
135   *out_server = std::move(ret);
136   return Status::OK();
137 }
138 
139 namespace {
140 
141 class VerbsServerFactory : public ServerFactory {
142  public:
AcceptsOptions(const ServerDef & server_def)143   bool AcceptsOptions(const ServerDef& server_def) override {
144     return server_def.protocol() == "grpc+verbs";
145   }
146 
NewServer(const ServerDef & server_def,std::unique_ptr<ServerInterface> * out_server)147   Status NewServer(const ServerDef& server_def,
148                    std::unique_ptr<ServerInterface>* out_server) override {
149     return VerbsServer::Create(server_def, Env::Default(), out_server);
150   }
151 };
152 
153 // Registers a `ServerFactory` for `VerbsServer` instances.
154 class VerbsServerRegistrar {
155  public:
VerbsServerRegistrar()156   VerbsServerRegistrar() {
157     gpr_allocation_functions alloc_fns;
158     alloc_fns.malloc_fn = port::Malloc;
159     alloc_fns.realloc_fn = port::Realloc;
160     alloc_fns.free_fn = port::Free;
161     gpr_set_allocation_functions(alloc_fns);
162     ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
163   }
164 };
165 static VerbsServerRegistrar registrar;
166 
167 }  // namespace
168 }  // namespace tensorflow
169 
170 #endif
171