1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifdef TENSORFLOW_USE_VERBS
17
18 #include "tensorflow/contrib/verbs/verbs_server_lib.h"
19
20 #include "grpc/support/alloc.h"
21
22 #include "tensorflow/contrib/verbs/rdma_mgr.h"
23 #include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
24 #include "tensorflow/core/distributed_runtime/server_lib.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/env.h"
27
28 namespace tensorflow {
29
30 namespace {
31 // static utility function
NewRdmaRendezvousMgr(const WorkerEnv * env)32 RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) {
33 return new RdmaRendezvousMgr(env);
34 }
35
36 } // namespace
37
VerbsServer(const ServerDef & server_def,Env * env)38 VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
39 : GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
40
~VerbsServer()41 VerbsServer::~VerbsServer() {
42 TF_CHECK_OK(Stop());
43 TF_CHECK_OK(Join());
44 delete rdma_mgr_;
45 delete verbs_service_;
46 delete channel_cache_;
47 }
48
ChannelCacheFactory(const ServerDef & server_def,GrpcChannelCache ** channel_cache)49 Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
50 GrpcChannelCache** channel_cache) {
51 string name_prefix =
52 strings::StrCat("/job:", server_def.job_name(), "/replica:0",
53 "/task:", server_def.task_index());
54
55 GrpcChannelSpec channel_spec;
56 TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
57
58 *channel_cache =
59 NewGrpcChannelCache(channel_spec, GetChannelCreationFunction());
60
61 const string host_port = (*channel_cache)->TranslateTask(name_prefix);
62 int requested_port;
63
64 if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
65 &requested_port)) {
66 return errors::Internal("Could not parse port for local server from \"",
67 (*channel_cache)->TranslateTask(name_prefix),
68 "\".");
69 }
70 if (requested_port != bound_port()) {
71 return errors::InvalidArgument("Requested port ", requested_port,
72 " differs from expected port ",
73 bound_port());
74 }
75
76 return Status::OK();
77 }
78
Init(ServiceInitFunction service_func,RendezvousMgrCreationFunction rendezvous_mgr_func)79 Status VerbsServer::Init(ServiceInitFunction service_func,
80 RendezvousMgrCreationFunction rendezvous_mgr_func) {
81 Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
82 {
83 mutex_lock l(mu_);
84 CHECK_EQ(verbs_state_, DISCONNECTED);
85 CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
86 rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
87 // set rdma_mgr for verbs_service and rdma_rendezvous_mgr
88 verbs_service_->SetRdmaMgr(rdma_mgr_);
89 dynamic_cast<RdmaRendezvousMgr*>(worker_env()->rendezvous_mgr)
90 ->SetRdmaMgr(rdma_mgr_);
91 }
92 return s;
93 }
94
Start()95 Status VerbsServer::Start() {
96 Status s = GrpcServer::Start();
97 {
98 mutex_lock l(mu_);
99 if (verbs_state_ == DISCONNECTED) {
100 // verbs_thread needs to be initiated
101 // before rdma_mgr sets up the rdma channels.
102 verbs_thread_.reset(worker_env()->env->StartThread(
103 ThreadOptions(), "TF_verbs_service",
104 [this] { verbs_service_->HandleRPCsLoop(); }));
105 rdma_mgr_->SetupChannels();
106 CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
107 rdma_mgr_->InitAllocators();
108 verbs_state_ = CONNECTED;
109 }
110 }
111 return s;
112 }
113
Join()114 Status VerbsServer::Join() {
115 Status s = GrpcServer::Join();
116 {
117 mutex_lock l(mu_);
118 if (verbs_state_ == CONNECTED) {
119 verbs_state_ = DISCONNECTED;
120 verbs_thread_.reset();
121 }
122 }
123 return s;
124 }
125
126 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)127 Status VerbsServer::Create(const ServerDef& server_def, Env* env,
128 std::unique_ptr<ServerInterface>* out_server) {
129 std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
130 ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
131 ::grpc::ServerBuilder* builder) {
132 return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
133 };
134 TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
135 *out_server = std::move(ret);
136 return Status::OK();
137 }
138
139 namespace {
140
141 class VerbsServerFactory : public ServerFactory {
142 public:
AcceptsOptions(const ServerDef & server_def)143 bool AcceptsOptions(const ServerDef& server_def) override {
144 return server_def.protocol() == "grpc+verbs";
145 }
146
NewServer(const ServerDef & server_def,std::unique_ptr<ServerInterface> * out_server)147 Status NewServer(const ServerDef& server_def,
148 std::unique_ptr<ServerInterface>* out_server) override {
149 return VerbsServer::Create(server_def, Env::Default(), out_server);
150 }
151 };
152
153 // Registers a `ServerFactory` for `VerbsServer` instances.
154 class VerbsServerRegistrar {
155 public:
VerbsServerRegistrar()156 VerbsServerRegistrar() {
157 gpr_allocation_functions alloc_fns;
158 alloc_fns.malloc_fn = port::Malloc;
159 alloc_fns.realloc_fn = port::Realloc;
160 alloc_fns.free_fn = port::Free;
161 gpr_set_allocation_functions(alloc_fns);
162 ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
163 }
164 };
165 static VerbsServerRegistrar registrar;
166
167 } // namespace
168 } // namespace tensorflow
169
170 #endif
171