/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

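// Session targets look like "grpc://host:port"; the scheme prefix is
// stripped off before the host:port pair is handed to gRPC.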
const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, the client can disable the use of the local master
  // registry, so that the full RPC stack is exercised even for an
  // in-process master.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(
        NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                               &options.config.rpc_options(), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  } else {
    session->is_local_ = true;
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return Status::OK();
}
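
// A minimal usage sketch (the target address below is a placeholder and
// assumes a server, e.g. a tf.train.Server, is already listening there):
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   std::unique_ptr<GrpcSession> session;
//   TF_CHECK_OK(GrpcSession::Create(options, &session));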

namespace {
// Re-encodes constants represented in the tensor proto's repeated value
// fields into tensor_content, which is slightly better (fewer copies and
// lower peak memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and it is
        // moderately large, we re-encode it in tensor_content as a Cord.
        // This is mildly helpful for reducing the peak memory usage on the
        // server side, where GraphDef/NodeDef are copied quite often.
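        // For example, a float Const with 1000 elements stored as 1000
        // repeated float_val entries is re-serialized below as a single
        // packed tensor_content byte string.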
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
}  // namespace

void GrpcSession::SetHandleAndGraphVersion(string handle,
                                           int64 graph_version) {
  mutex_lock l(mu_);
  handle_ = std::move(handle);
  current_graph_version_ = graph_version;
}

Status GrpcSession::Handle(string* out_handle) {
  mutex_lock l(mu_);
  if (handle_.empty()) {
    return errors::InvalidArgument("A session has not been created yet.");
  }
  *out_handle = handle_;
  return Status::OK();
}

Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is already created.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  req.mutable_graph_def()->Swap(&graph);
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version());
  }
  return s;
}

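// The const-reference Create()/Extend() overloads copy the graph and
// delegate to the rvalue overloads, which move it into the request and
// avoid a second copy of a potentially large GraphDef.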
Status GrpcSession::Create(const GraphDef& graph) {
  return Create(GraphDef(graph));
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Create(run_options, GraphDef(graph));
}

Status GrpcSession::Create(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

Status GrpcSession::Create(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

Status GrpcSession::ExtendImpl(CallOptions* call_options, GraphDef graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with
    // 'graph'.
    return Create(std::move(graph));
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  req.mutable_graph_def()->Swap(&graph);
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  return Extend(GraphDef(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Extend(run_options, GraphDef(graph));
}

Status GrpcSession::Extend(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}

Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

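  // Fall back to the session-wide operation timeout when the caller has
  // not set one on RunOptions.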
  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error status in the
  // response body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to the first index at which it
  // appears in output_tensor_names.
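  // Duplicate names are fetched only once; their values are copied into
  // the remaining output positions after the response arrives (see below).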
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return Status(resp->status_code(), resp->status_error_message());
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates,
  // fill in the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return Status::OK();
}

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}

Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  string handle;
  TF_RETURN_IF_ERROR(Handle(&handle));
  req->set_session_handle(handle);
  return master_->RunStep(call_options, req, resp);
}

Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return Status::OK();
}

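// A partial run is driven by PRunSetup(), which declares all feeds,
// fetches, and targets up front, followed by one or more PRun() calls that
// incrementally supply feeds and retrieve fetches for the returned handle.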
Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {},
                   outputs, /* run_metadata */ nullptr, handle);
}

Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return Status::OK();
    }
    req.set_session_handle(handle_);
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}

Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session "
                    "with an empty graph and other defaults because the "
                    "session has not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return Status::OK();
}

void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

// Static method.
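// Reset() does not require an existing session: it dials the master
// directly and asks it to clear the named resource containers.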
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(
      NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                             /*rpc_options=*/nullptr, &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  req.mutable_container()->Reserve(containers.size());
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}

Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
                                 CallableHandle* out_handle) {
  MakeCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  *req.mutable_options() = callable_options;
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  MakeCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
  *out_handle = resp.handle();
  return Status::OK();
}

Status GrpcSession::RunCallable(CallableHandle handle,
                                const std::vector<Tensor>& feed_tensors,
                                std::vector<Tensor>* fetch_tensors,
                                RunMetadata* run_metadata) {
  RunCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
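  // Feed tensors are serialized with the packed tensor_content encoding,
  // which avoids per-element repeated proto fields (cf. ReEncodeConsts
  // above).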
  for (const Tensor& feed : feed_tensors) {
    feed.AsProtoTensorContent(req.mutable_feed()->Add());
  }

  RunCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
  for (const TensorProto& fetch : resp.fetch()) {
    Tensor fetch_tensor;
    if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
      return errors::Internal(
          "Could not parse fetched tensor data in response from master.");
    }
    fetch_tensors->push_back(std::move(fetch_tensor));
  }
  return Status::OK();
}

Status GrpcSession::ReleaseCallable(CallableHandle handle) {
  ReleaseCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  ReleaseCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->ReleaseCallable(&call_options, &req, &resp);
}

class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return absl::StartsWith(options.target, kSchemePrefix);
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    std::unique_ptr<GrpcSession> session;
    TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
    *out_session = session.release();
    return Status::OK();
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

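// Constructing this registrar at static-initialization time registers the
// factory, so "grpc://" targets become resolvable through the generic
// NewSession() factory function as soon as this translation unit is linked.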
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow