/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

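// Session targets are expected to be of the form "grpc://host:port"; the
// scheme prefix is stripped before a channel to the master is created.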
const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, the client may disable the use of the local master registry
  // so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(
        NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                               &options.config.rpc_options(), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  } else {
    session->is_local_ = true;
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return Status::OK();
}

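// Example: creating a GrpcSession directly. This is a minimal sketch; it
// assumes a TensorFlow server is listening at the given target, and
// 'graph_def' is an illustrative GraphDef supplied by the caller:
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   std::unique_ptr<GrpcSession> session;
//   TF_CHECK_OK(GrpcSession::Create(options, &session));
//   TF_CHECK_OK(session->Create(graph_def));
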
namespace {
// Re-encodes constants represented in the tensor proto into tensor_content,
// which is slightly better (fewer copies and lower peak memory usage) when
// used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and it is
        // moderately large, we re-encode it in tensor_content as a Cord.
        // This is mildly helpful for reducing the peak memory usage on the
        // server side, where GraphDef/NodeDef are copied quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
}  // namespace

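// Records the session handle and graph version returned by the master once a
// CreateSession call succeeds.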
void GrpcSession::SetHandleAndGraphVersion(string handle,
                                           int64 graph_version) {
  mutex_lock l(mu_);
  handle_ = std::move(handle);
  current_graph_version_ = graph_version;
}

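// Copies the session handle into *out_handle, or returns an error if the
// session has not been created yet.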
Status GrpcSession::Handle(string* out_handle) {
  mutex_lock l(mu_);
  if (handle_.empty()) {
    return errors::InvalidArgument("A session is not created yet....");
  }
  *out_handle = handle_;
  return Status::OK();
}

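// Sends a CreateSession request carrying the session's ConfigProto and the
// (re-encoded) graph; on success, records the returned handle and graph
// version.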
Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is alive.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  req.mutable_graph_def()->Swap(&graph);
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version());
  }
  return s;
}

Status GrpcSession::Create(const GraphDef& graph) {
  return Create(GraphDef(graph));
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Create(run_options, GraphDef(graph));
}

Status GrpcSession::Create(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

Status GrpcSession::Create(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

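// Sends an ExtendSession request for 'graph'. If the session has not been
// created yet, falls back to creating it with 'graph' instead.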
Status GrpcSession::ExtendImpl(CallOptions* call_options, GraphDef graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with
    // 'graph'.
    return Create(std::move(graph));
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  req.mutable_graph_def()->Swap(&graph);
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  return Extend(GraphDef(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Extend(run_options, GraphDef(graph));
}

Status GrpcSession::Extend(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}

Status GrpcSession::Extend(const RunOptions& run_options, GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, std::move(graph));
}

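// Shared implementation of Run() and PRun(): builds a RunStep request from
// the given feeds, fetches, and targets, issues the RPC, and unpacks the
// fetched tensors into 'outputs' in the order they were requested.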
Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code in the response
  // body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return Status(resp->status_code(), resp->status_error_message());
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert the response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill
  // in the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return Status::OK();
}

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}

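// Attaches the session handle to the request and forwards it to the master
// as a RunStep call.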
Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  string handle;
  TF_RETURN_IF_ERROR(Handle(&handle));
  req->set_session_handle(handle);
  return master_->RunStep(call_options, req, resp);
}

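// Registers the feeds, fetches, and targets of an upcoming sequence of
// partial runs with the master, returning the partial-run handle to pass to
// PRun().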
Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return Status::OK();
}

Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs,
                   /* run_metadata */ nullptr, handle);
}

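// Closes the session on the master, if one was created; calling Close() on a
// session that was never created is a no-op.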
Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return Status::OK();
    }
    req.set_session_handle(handle_);
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}

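// Lists the devices visible to this session. If the session has not been
// created yet, it is first created with an empty graph, since the master
// expects a valid session handle on the request.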
Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session with "
                    "an empty graph and other defaults because the session has "
                    "not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return Status::OK();
}

void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

// Static method. Resets the named containers on the target master; it
// creates its own master channel, so no existing session is required.
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(
      NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                             /*rpc_options=*/nullptr, &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  req.mutable_container()->Reserve(containers.size());
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}

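// Registers a callable (a pre-declared set of feeds, fetches, and targets)
// with the master and returns its handle.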
Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
                                 CallableHandle* out_handle) {
  MakeCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  *req.mutable_options() = callable_options;
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  MakeCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
  *out_handle = resp.handle();
  return Status::OK();
}

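// Executes a previously created callable: feed tensors are serialized into
// the request and fetched tensors are parsed back out of the response.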
Status GrpcSession::RunCallable(CallableHandle handle,
                                const std::vector<Tensor>& feed_tensors,
                                std::vector<Tensor>* fetch_tensors,
                                RunMetadata* run_metadata) {
  RunCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  if (!is_local_) req.set_request_id(GetUniqueRequestId());
  for (const Tensor& feed : feed_tensors) {
    feed.AsProtoTensorContent(req.mutable_feed()->Add());
  }

  RunCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
  for (const TensorProto& fetch : resp.fetch()) {
    Tensor fetch_tensor;
    if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
      return errors::Internal(
          "Could not parse fetched tensor data in response from master.");
    }
    fetch_tensors->push_back(std::move(fetch_tensor));
  }
  return Status::OK();
}

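// Releases the server-side resources associated with a callable handle.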
Status GrpcSession::ReleaseCallable(CallableHandle handle) {
  ReleaseCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  ReleaseCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->ReleaseCallable(&call_options, &req, &resp);
}

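// A SessionFactory that handles targets with the "grpc://" scheme by
// creating GrpcSessions.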
class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return absl::StartsWith(options.target, kSchemePrefix);
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    std::unique_ptr<GrpcSession> session;
    TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
    *out_session = session.release();
    return Status::OK();
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

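// Registers the factory under the name "GRPC_SESSION" with the global
// session factory registry at static initialization time.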
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow