1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
17
18 #include "tensorflow/core/framework/cost_graph.pb.h"
19 #include "tensorflow/core/framework/step_stats.pb.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/protobuf/config.pb.h"
22 #include "tensorflow/core/protobuf/named_tensor.pb.h"
23
24 namespace tensorflow {
25
26 namespace {
27
ParseTensorProtoToTensor(const TensorProto & tensor_proto,Tensor * out_tensor)28 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
29 Tensor* out_tensor) {
30 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
31 Tensor parsed(tensor_proto.dtype());
32 if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
33 *out_tensor = parsed;
34 return true;
35 }
36 }
37 return false;
38 }
39
40 } // namespace
41
session_handle() const42 const string& InMemoryRunStepRequest::session_handle() const {
43 return session_handle_;
44 }
45
set_session_handle(const string & handle)46 void InMemoryRunStepRequest::set_session_handle(const string& handle) {
47 session_handle_ = handle;
48 }
49
partial_run_handle() const50 const string& InMemoryRunStepRequest::partial_run_handle() const {
51 return partial_run_handle_;
52 }
53
set_partial_run_handle(const string & handle)54 void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
55 partial_run_handle_ = handle;
56 }
57
num_feeds() const58 size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
feed_name(size_t i) const59 const string& InMemoryRunStepRequest::feed_name(size_t i) const {
60 return feeds_[i].first;
61 }
62
FeedValue(size_t i,Tensor * out_tensor) const63 Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
64 *out_tensor = feeds_[i].second;
65 return Status::OK();
66 }
67
FeedValue(size_t i,TensorProto * out_tensor) const68 Status InMemoryRunStepRequest::FeedValue(size_t i,
69 TensorProto* out_tensor) const {
70 feeds_[i].second.AsProtoTensorContent(out_tensor);
71 return Status::OK();
72 }
73
add_feed(const string & name,const Tensor & value)74 void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
75 feeds_.emplace_back(name, value);
76 }
77
num_fetches() const78 size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
fetch_name(size_t i) const79 const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
80 return fetches_[i];
81 }
add_fetch(const string & name)82 void InMemoryRunStepRequest::add_fetch(const string& name) {
83 fetches_.push_back(name);
84 }
85
num_targets() const86 size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
target_name(size_t i) const87 const string& InMemoryRunStepRequest::target_name(size_t i) const {
88 return targets_[i];
89 }
add_target(const string & name)90 void InMemoryRunStepRequest::add_target(const string& name) {
91 targets_.push_back(name);
92 }
93
options() const94 const RunOptions& InMemoryRunStepRequest::options() const { return options_; }
95
mutable_options()96 RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }
97
store_errors_in_response_body() const98 bool InMemoryRunStepRequest::store_errors_in_response_body() const {
99 return store_errors_in_response_body_;
100 }
101
request_id() const102 int64 InMemoryRunStepRequest::request_id() const {
103 return 0; // no need to track request id for local version.
104 }
105
set_store_errors_in_response_body(bool store_errors)106 void InMemoryRunStepRequest::set_store_errors_in_response_body(
107 bool store_errors) {
108 store_errors_in_response_body_ = store_errors;
109 }
110
DebugString() const111 string InMemoryRunStepRequest::DebugString() const {
112 return ToProto().DebugString();
113 }
114
ToProto() const115 const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
116 if (!proto_version_) {
117 proto_version_.reset(new RunStepRequest);
118 proto_version_->set_session_handle(session_handle());
119 proto_version_->set_partial_run_handle(partial_run_handle());
120 for (size_t i = 0; i < num_feeds(); ++i) {
121 auto feed = proto_version_->add_feed();
122 feed->set_name(feed_name(i));
123 feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
124 }
125 for (size_t i = 0; i < num_fetches(); ++i) {
126 proto_version_->add_fetch(fetch_name(i));
127 }
128 for (size_t i = 0; i < num_targets(); ++i) {
129 proto_version_->add_target(target_name(i));
130 }
131 *proto_version_->mutable_options() = options();
132 }
133 return *proto_version_;
134 }
135
session_handle() const136 const string& MutableProtoRunStepRequest::session_handle() const {
137 return request_.session_handle();
138 }
set_session_handle(const string & handle)139 void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
140 request_.set_session_handle(handle);
141 }
142
partial_run_handle() const143 const string& MutableProtoRunStepRequest::partial_run_handle() const {
144 return request_.partial_run_handle();
145 }
set_partial_run_handle(const string & handle)146 void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
147 request_.set_partial_run_handle(handle);
148 }
149
num_feeds() const150 size_t MutableProtoRunStepRequest::num_feeds() const {
151 return request_.feed_size();
152 }
feed_name(size_t i) const153 const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
154 return request_.feed(i).name();
155 }
FeedValue(size_t i,Tensor * out_tensor) const156 Status MutableProtoRunStepRequest::FeedValue(size_t i,
157 Tensor* out_tensor) const {
158 if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
159 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
160 } else {
161 return Status::OK();
162 }
163 }
164
FeedValue(size_t i,TensorProto * out_tensor) const165 Status MutableProtoRunStepRequest::FeedValue(size_t i,
166 TensorProto* out_tensor) const {
167 *out_tensor = request_.feed(i).tensor();
168 return Status::OK();
169 }
170
add_feed(const string & name,const Tensor & value)171 void MutableProtoRunStepRequest::add_feed(const string& name,
172 const Tensor& value) {
173 NamedTensorProto* feed = request_.add_feed();
174 feed->set_name(name);
175 TensorProto* value_proto = feed->mutable_tensor();
176 value.AsProtoTensorContent(value_proto);
177 }
178
num_fetches() const179 size_t MutableProtoRunStepRequest::num_fetches() const {
180 return request_.fetch_size();
181 }
182
fetch_name(size_t i) const183 const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
184 return request_.fetch(i);
185 }
add_fetch(const string & name)186 void MutableProtoRunStepRequest::add_fetch(const string& name) {
187 request_.add_fetch(name);
188 }
189
num_targets() const190 size_t MutableProtoRunStepRequest::num_targets() const {
191 return request_.target_size();
192 }
193
target_name(size_t i) const194 const string& MutableProtoRunStepRequest::target_name(size_t i) const {
195 return request_.target(i);
196 }
197
add_target(const string & name)198 void MutableProtoRunStepRequest::add_target(const string& name) {
199 request_.add_target(name);
200 }
201
options() const202 const RunOptions& MutableProtoRunStepRequest::options() const {
203 return request_.options();
204 }
205
mutable_options()206 RunOptions* MutableProtoRunStepRequest::mutable_options() {
207 return request_.mutable_options();
208 }
209
store_errors_in_response_body() const210 bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
211 return request_.store_errors_in_response_body();
212 }
213
set_store_errors_in_response_body(bool store_errors)214 void MutableProtoRunStepRequest::set_store_errors_in_response_body(
215 bool store_errors) {
216 request_.set_store_errors_in_response_body(store_errors);
217 }
218
request_id() const219 int64 MutableProtoRunStepRequest::request_id() const {
220 return request_.request_id();
221 }
222
DebugString() const223 string MutableProtoRunStepRequest::DebugString() const {
224 return request_.DebugString();
225 }
226
ToProto() const227 const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
228 return request_;
229 }
230
ProtoRunStepRequest(const RunStepRequest * request)231 ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
232 : request_(request) {}
233
session_handle() const234 const string& ProtoRunStepRequest::session_handle() const {
235 return request_->session_handle();
236 }
237
partial_run_handle() const238 const string& ProtoRunStepRequest::partial_run_handle() const {
239 return request_->partial_run_handle();
240 }
241
num_feeds() const242 size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }
243
feed_name(size_t i) const244 const string& ProtoRunStepRequest::feed_name(size_t i) const {
245 return request_->feed(i).name();
246 }
247
FeedValue(size_t i,Tensor * out_tensor) const248 Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
249 if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
250 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
251 } else {
252 return Status::OK();
253 }
254 }
255
FeedValue(size_t i,TensorProto * out_tensor) const256 Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
257 *out_tensor = request_->feed(i).tensor();
258 return Status::OK();
259 }
260
num_fetches() const261 size_t ProtoRunStepRequest::num_fetches() const {
262 return request_->fetch_size();
263 }
264
fetch_name(size_t i) const265 const string& ProtoRunStepRequest::fetch_name(size_t i) const {
266 return request_->fetch(i);
267 }
268
num_targets() const269 size_t ProtoRunStepRequest::num_targets() const {
270 return request_->target_size();
271 }
272
target_name(size_t i) const273 const string& ProtoRunStepRequest::target_name(size_t i) const {
274 return request_->target(i);
275 }
276
options() const277 const RunOptions& ProtoRunStepRequest::options() const {
278 return request_->options();
279 }
280
store_errors_in_response_body() const281 bool ProtoRunStepRequest::store_errors_in_response_body() const {
282 return request_->store_errors_in_response_body();
283 }
284
request_id() const285 int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }
286
DebugString() const287 string ProtoRunStepRequest::DebugString() const {
288 return request_->DebugString();
289 }
290
ToProto() const291 const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
292
session_handle() const293 const string& InMemoryRunGraphRequest::session_handle() const {
294 return session_handle_;
295 }
296
create_worker_session_called() const297 bool InMemoryRunGraphRequest::create_worker_session_called() const {
298 return create_worker_session_called_;
299 }
300
set_session_handle(const string & handle)301 void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
302 session_handle_ = handle;
303 }
304
set_create_worker_session_called(bool called)305 void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
306 create_worker_session_called_ = called;
307 }
308
graph_handle() const309 const string& InMemoryRunGraphRequest::graph_handle() const {
310 return graph_handle_;
311 }
312
set_graph_handle(const string & handle)313 void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
314 graph_handle_ = handle;
315 }
316
step_id() const317 int64 InMemoryRunGraphRequest::step_id() const { return step_id_; }
318
set_step_id(int64 step_id)319 void InMemoryRunGraphRequest::set_step_id(int64 step_id) { step_id_ = step_id; }
320
exec_opts() const321 const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
322 return exec_opts_;
323 }
324
mutable_exec_opts()325 ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
326 return &exec_opts_;
327 }
328
num_sends() const329 size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }
330
send_key(size_t i) const331 const string& InMemoryRunGraphRequest::send_key(size_t i) const {
332 return sends_[i].first;
333 }
334
SendValue(size_t i,Tensor * out_tensor) const335 Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
336 *out_tensor = sends_[i].second;
337 return Status::OK();
338 }
339
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)340 Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
341 const RunStepRequestWrapper& run_step_request, size_t i,
342 const string& send_key) {
343 Tensor tensor;
344 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
345 sends_.emplace_back(send_key, std::move(tensor));
346 return Status::OK();
347 }
348
349 // TODO(b/74355905): Add a specialized implementation that avoids
350 // copying the tensor when at least two of the {client, master,
351 // worker} are in the same process.
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)352 Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
353 const RunCallableRequest& run_callable_request, size_t i,
354 const string& send_key) {
355 Tensor tensor;
356 if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
357 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
358 }
359 sends_.emplace_back(send_key, std::move(tensor));
360 return Status::OK();
361 }
362
num_recvs() const363 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
364
recv_key(size_t i) const365 const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
366 return recvs_[i];
367 }
368
add_recv_key(const string & recv_key)369 void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
370 recvs_.push_back(recv_key);
371 }
372
is_partial() const373 bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }
374
set_is_partial(bool is_partial)375 void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
376 is_partial_ = is_partial;
377 }
378
is_last_partial_run() const379 bool InMemoryRunGraphRequest::is_last_partial_run() const {
380 return is_last_partial_run_;
381 }
382
set_is_last_partial_run(bool is_last_partial_run)383 void InMemoryRunGraphRequest::set_is_last_partial_run(
384 bool is_last_partial_run) {
385 is_last_partial_run_ = is_last_partial_run;
386 }
387
store_errors_in_response_body() const388 bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
389 return store_errors_in_response_body_;
390 }
391
set_store_errors_in_response_body(bool store_errors)392 void InMemoryRunGraphRequest::set_store_errors_in_response_body(
393 bool store_errors) {
394 store_errors_in_response_body_ = store_errors;
395 }
396
request_id() const397 int64 InMemoryRunGraphRequest::request_id() const { return request_id_; }
398
set_request_id(int64 request_id)399 void InMemoryRunGraphRequest::set_request_id(int64 request_id) {
400 request_id_ = request_id;
401 }
402
ToProto() const403 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
404 if (!proto_version_) {
405 proto_version_.reset(new RunGraphRequest);
406 proto_version_->set_session_handle(session_handle());
407 proto_version_->set_create_worker_session_called(
408 create_worker_session_called());
409 proto_version_->set_graph_handle(graph_handle());
410 proto_version_->set_step_id(step_id());
411 *proto_version_->mutable_exec_opts() = exec_opts();
412 for (size_t i = 0; i < num_sends(); ++i) {
413 auto send = proto_version_->add_send();
414 send->set_name(send_key(i));
415 sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
416 }
417 for (size_t i = 0; i < num_recvs(); ++i) {
418 proto_version_->add_recv_key(recv_key(i));
419 }
420 proto_version_->set_is_partial(is_partial());
421 proto_version_->set_is_last_partial_run(is_last_partial_run());
422 }
423 proto_version_->set_store_errors_in_response_body(
424 store_errors_in_response_body_);
425 proto_version_->set_request_id(request_id_);
426 return *proto_version_;
427 }
428
session_handle() const429 const string& MutableProtoRunGraphRequest::session_handle() const {
430 return request_.session_handle();
431 }
432
set_session_handle(const string & handle)433 void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
434 request_.set_session_handle(handle);
435 }
436
create_worker_session_called() const437 bool MutableProtoRunGraphRequest::create_worker_session_called() const {
438 return request_.create_worker_session_called();
439 }
440
set_create_worker_session_called(bool called)441 void MutableProtoRunGraphRequest::set_create_worker_session_called(
442 bool called) {
443 request_.set_create_worker_session_called(called);
444 }
445
graph_handle() const446 const string& MutableProtoRunGraphRequest::graph_handle() const {
447 return request_.graph_handle();
448 }
449
set_graph_handle(const string & handle)450 void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
451 request_.set_graph_handle(handle);
452 }
453
step_id() const454 int64 MutableProtoRunGraphRequest::step_id() const {
455 return request_.step_id();
456 }
457
set_step_id(int64 step_id)458 void MutableProtoRunGraphRequest::set_step_id(int64 step_id) {
459 request_.set_step_id(step_id);
460 }
461
exec_opts() const462 const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
463 return request_.exec_opts();
464 }
465
mutable_exec_opts()466 ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
467 return request_.mutable_exec_opts();
468 }
469
num_sends() const470 size_t MutableProtoRunGraphRequest::num_sends() const {
471 return request_.send_size();
472 }
473
send_key(size_t i) const474 const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
475 return request_.send(i).name();
476 }
477
SendValue(size_t i,Tensor * out_tensor) const478 Status MutableProtoRunGraphRequest::SendValue(size_t i,
479 Tensor* out_tensor) const {
480 if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
481 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
482 } else {
483 return Status::OK();
484 }
485 }
486
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)487 Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
488 const RunStepRequestWrapper& run_step_request, size_t i,
489 const string& send_key) {
490 NamedTensorProto* send = request_.add_send();
491 send->set_name(send_key);
492 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
493 return Status::OK();
494 }
495
496 // TODO(b/74355905): Add a specialized implementation that avoids
497 // copying the tensor when at least two of the {client, master,
498 // worker} are in the same process.
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)499 Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
500 const RunCallableRequest& run_callable_request, size_t i,
501 const string& send_key) {
502 NamedTensorProto* send = request_.add_send();
503 send->set_name(send_key);
504 *send->mutable_tensor() = run_callable_request.feed(i);
505 return Status::OK();
506 }
507
num_recvs() const508 size_t MutableProtoRunGraphRequest::num_recvs() const {
509 return request_.recv_key_size();
510 }
511
recv_key(size_t i) const512 const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
513 return request_.recv_key(i);
514 }
515
add_recv_key(const string & recv_key)516 void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
517 request_.add_recv_key(recv_key);
518 }
519
is_partial() const520 bool MutableProtoRunGraphRequest::is_partial() const {
521 return request_.is_partial();
522 }
523
set_is_partial(bool is_partial)524 void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
525 request_.set_is_partial(is_partial);
526 }
527
is_last_partial_run() const528 bool MutableProtoRunGraphRequest::is_last_partial_run() const {
529 return request_.is_last_partial_run();
530 }
531
set_is_last_partial_run(bool is_last_partial_run)532 void MutableProtoRunGraphRequest::set_is_last_partial_run(
533 bool is_last_partial_run) {
534 request_.set_is_last_partial_run(is_last_partial_run);
535 }
536
store_errors_in_response_body() const537 bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
538 return request_.store_errors_in_response_body();
539 }
540
set_store_errors_in_response_body(bool store_errors)541 void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
542 bool store_errors) {
543 request_.set_store_errors_in_response_body(store_errors);
544 }
545
request_id() const546 int64 MutableProtoRunGraphRequest::request_id() const {
547 return request_.request_id();
548 }
549
set_request_id(int64 request_id)550 void MutableProtoRunGraphRequest::set_request_id(int64 request_id) {
551 request_.set_request_id(request_id);
552 }
553
ToProto() const554 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
555 return request_;
556 }
557
ProtoRunGraphRequest(const RunGraphRequest * request)558 ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
559 : request_(request) {}
560
session_handle() const561 const string& ProtoRunGraphRequest::session_handle() const {
562 return request_->session_handle();
563 }
564
create_worker_session_called() const565 bool ProtoRunGraphRequest::create_worker_session_called() const {
566 return request_->create_worker_session_called();
567 }
568
graph_handle() const569 const string& ProtoRunGraphRequest::graph_handle() const {
570 return request_->graph_handle();
571 }
572
step_id() const573 int64 ProtoRunGraphRequest::step_id() const { return request_->step_id(); }
574
exec_opts() const575 const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
576 return request_->exec_opts();
577 }
578
num_sends() const579 size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }
580
send_key(size_t i) const581 const string& ProtoRunGraphRequest::send_key(size_t i) const {
582 return request_->send(i).name();
583 }
584
SendValue(size_t i,Tensor * out_tensor) const585 Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
586 if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
587 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
588 } else {
589 return Status::OK();
590 }
591 }
592
num_recvs() const593 size_t ProtoRunGraphRequest::num_recvs() const {
594 return request_->recv_key_size();
595 }
596
recv_key(size_t i) const597 const string& ProtoRunGraphRequest::recv_key(size_t i) const {
598 return request_->recv_key(i);
599 }
600
is_partial() const601 bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }
602
is_last_partial_run() const603 bool ProtoRunGraphRequest::is_last_partial_run() const {
604 return request_->is_last_partial_run();
605 }
606
store_errors_in_response_body() const607 bool ProtoRunGraphRequest::store_errors_in_response_body() const {
608 return request_->store_errors_in_response_body();
609 }
610
request_id() const611 int64 ProtoRunGraphRequest::request_id() const {
612 return request_->request_id();
613 }
614
ToProto() const615 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
616 return *request_;
617 }
618
num_recvs() const619 size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }
620
recv_key(size_t i) const621 const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
622 return recvs_[i].first;
623 }
624
RecvValue(size_t i,TensorProto * out_tensor)625 Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
626 recvs_[i].second.AsProtoTensorContent(out_tensor);
627 return Status::OK();
628 }
629
RecvValue(size_t i,Tensor * out_tensor)630 Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
631 *out_tensor = recvs_[i].second;
632 return Status::OK();
633 }
634
AddRecv(const string & key,const Tensor & value)635 void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
636 recvs_.emplace_back(key, value);
637 }
638
mutable_step_stats()639 StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
640 return &step_stats_;
641 }
642
mutable_cost_graph()643 CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
644 return &cost_graph_;
645 }
646
status_code() const647 errors::Code InMemoryRunGraphResponse::status_code() const {
648 return status_.code();
649 }
650
status_error_message() const651 const string& InMemoryRunGraphResponse::status_error_message() const {
652 return status_.error_message();
653 }
654
set_status(const Status & status)655 void InMemoryRunGraphResponse::set_status(const Status& status) {
656 status_ = status;
657 }
658
get_proto()659 RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
660 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
661 return nullptr;
662 }
663
num_partition_graphs() const664 size_t InMemoryRunGraphResponse::num_partition_graphs() const {
665 return partition_graphs_.size();
666 }
667
mutable_partition_graph(size_t i)668 GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
669 return &partition_graphs_[i];
670 }
671
AddPartitionGraph(const GraphDef & partition_graph)672 void InMemoryRunGraphResponse::AddPartitionGraph(
673 const GraphDef& partition_graph) {
674 partition_graphs_.push_back(partition_graph);
675 }
676
num_recvs() const677 size_t OwnedProtoRunGraphResponse::num_recvs() const {
678 return response_.recv_size();
679 }
680
recv_key(size_t i) const681 const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
682 return response_.recv(i).name();
683 }
684
RecvValue(size_t i,TensorProto * out_tensor)685 Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
686 TensorProto* out_tensor) {
687 out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
688 return Status::OK();
689 }
690
RecvValue(size_t i,Tensor * out_tensor)691 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
692 if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
693 return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
694 } else {
695 return Status::OK();
696 }
697 }
698
AddRecv(const string & key,const Tensor & value)699 void OwnedProtoRunGraphResponse::AddRecv(const string& key,
700 const Tensor& value) {
701 NamedTensorProto* recv = response_.add_recv();
702 recv->set_name(key);
703 TensorProto* value_proto = recv->mutable_tensor();
704 value.AsProtoTensorContent(value_proto);
705 }
706
mutable_step_stats()707 StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
708 return response_.mutable_step_stats();
709 }
710
mutable_cost_graph()711 CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
712 return response_.mutable_cost_graph();
713 }
714
status_code() const715 errors::Code OwnedProtoRunGraphResponse::status_code() const {
716 return response_.status_code();
717 }
718
status_error_message() const719 const string& OwnedProtoRunGraphResponse::status_error_message() const {
720 return response_.status_error_message();
721 }
722
set_status(const Status & status)723 void OwnedProtoRunGraphResponse::set_status(const Status& status) {
724 response_.set_status_code(status.code());
725 response_.set_status_error_message(status.error_message());
726 }
727
get_proto()728 RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
729
num_partition_graphs() const730 size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
731 return response_.partition_graph_size();
732 }
733
mutable_partition_graph(size_t i)734 GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
735 return response_.mutable_partition_graph(i);
736 }
737
AddPartitionGraph(const GraphDef & partition_graph)738 void OwnedProtoRunGraphResponse::AddPartitionGraph(
739 const GraphDef& partition_graph) {
740 GraphDef* graph_def = response_.mutable_partition_graph()->Add();
741 *graph_def = partition_graph;
742 }
743
NonOwnedProtoRunGraphResponse(RunGraphResponse * response)744 NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
745 RunGraphResponse* response)
746 : response_(response) {}
747
num_recvs() const748 size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
749 return response_->recv_size();
750 }
751
recv_key(size_t i) const752 const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
753 return response_->recv(i).name();
754 }
755
RecvValue(size_t i,TensorProto * out_tensor)756 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
757 TensorProto* out_tensor) {
758 out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
759 return Status::OK();
760 }
761
RecvValue(size_t i,Tensor * out_tensor)762 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
763 if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
764 return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
765 } else {
766 return Status::OK();
767 }
768 }
769
AddRecv(const string & key,const Tensor & value)770 void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
771 const Tensor& value) {
772 NamedTensorProto* recv = response_->add_recv();
773 recv->set_name(key);
774 TensorProto* value_proto = recv->mutable_tensor();
775 value.AsProtoTensorContent(value_proto);
776 }
777
mutable_step_stats()778 StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
779 return response_->mutable_step_stats();
780 }
781
mutable_cost_graph()782 CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
783 return response_->mutable_cost_graph();
784 }
785
status_code() const786 errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
787 return response_->status_code();
788 }
789
status_error_message() const790 const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
791 return response_->status_error_message();
792 }
793
set_status(const Status & status)794 void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
795 response_->set_status_code(status.code());
796 response_->set_status_error_message(status.error_message());
797 }
798
get_proto()799 RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
800 return response_;
801 }
802
num_partition_graphs() const803 size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
804 return response_->partition_graph_size();
805 }
806
mutable_partition_graph(size_t i)807 GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
808 return response_->mutable_partition_graph(i);
809 }
810
AddPartitionGraph(const GraphDef & partition_graph)811 void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
812 const GraphDef& partition_graph) {
813 GraphDef* graph_def = response_->add_partition_graph();
814 *graph_def = partition_graph;
815 }
816
~MutableRunStepResponseWrapper()817 MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
818
num_tensors() const819 size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
820
tensor_name(size_t i) const821 const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
822 return tensors_[i].first;
823 }
824
TensorValue(size_t i,Tensor * out_tensor) const825 Status InMemoryRunStepResponse::TensorValue(size_t i,
826 Tensor* out_tensor) const {
827 *out_tensor = tensors_[i].second;
828 return Status::OK();
829 }
830
metadata() const831 const RunMetadata& InMemoryRunStepResponse::metadata() const {
832 return metadata_;
833 }
834
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * wrapper,size_t i)835 Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
836 const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
837 Tensor tensor;
838 TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
839 tensors_.emplace_back(name, tensor);
840 return Status::OK();
841 }
842
mutable_metadata()843 RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
844
status_code() const845 errors::Code InMemoryRunStepResponse::status_code() const {
846 return status_.code();
847 }
848
status_error_message() const849 const string& InMemoryRunStepResponse::status_error_message() const {
850 return status_.error_message();
851 }
852
set_status(const Status & status)853 void InMemoryRunStepResponse::set_status(const Status& status) {
854 status_ = status;
855 }
856
get_proto()857 RunStepResponse* InMemoryRunStepResponse::get_proto() {
858 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
859 return nullptr;
860 }
861
num_tensors() const862 size_t OwnedProtoRunStepResponse::num_tensors() const {
863 return response_.tensor_size();
864 }
865
tensor_name(size_t i) const866 const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
867 return response_.tensor(i).name();
868 }
869
TensorValue(size_t i,Tensor * out_tensor) const870 Status OwnedProtoRunStepResponse::TensorValue(size_t i,
871 Tensor* out_tensor) const {
872 if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
873 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
874 } else {
875 return Status::OK();
876 }
877 }
878
metadata() const879 const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
880 return response_.metadata();
881 }
882
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)883 Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
884 const string& name, MutableRunGraphResponseWrapper* run_graph_response,
885 size_t i) {
886 NamedTensorProto* response_tensor = response_.add_tensor();
887 response_tensor->set_name(name);
888 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
889 }
890
mutable_metadata()891 RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
892 return response_.mutable_metadata();
893 }
894
status_code() const895 errors::Code OwnedProtoRunStepResponse::status_code() const {
896 return response_.status_code();
897 }
898
status_error_message() const899 const string& OwnedProtoRunStepResponse::status_error_message() const {
900 return response_.status_error_message();
901 }
902
set_status(const Status & status)903 void OwnedProtoRunStepResponse::set_status(const Status& status) {
904 response_.set_status_code(status.code());
905 response_.set_status_error_message(status.error_message());
906 }
907
get_proto()908 RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
909
NonOwnedProtoRunStepResponse(RunStepResponse * response)910 NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
911 RunStepResponse* response)
912 : response_(response) {}
913
num_tensors() const914 size_t NonOwnedProtoRunStepResponse::num_tensors() const {
915 return response_->tensor_size();
916 }
917
tensor_name(size_t i) const918 const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
919 return response_->tensor(i).name();
920 }
921
TensorValue(size_t i,Tensor * out_tensor) const922 Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
923 Tensor* out_tensor) const {
924 if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
925 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
926 } else {
927 return Status::OK();
928 }
929 }
930
metadata() const931 const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
932 return response_->metadata();
933 }
934
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)935 Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
936 const string& name, MutableRunGraphResponseWrapper* run_graph_response,
937 size_t i) {
938 NamedTensorProto* response_tensor = response_->add_tensor();
939 response_tensor->set_name(name);
940 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
941 }
942
mutable_metadata()943 RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
944 return response_->mutable_metadata();
945 }
946
status_code() const947 errors::Code NonOwnedProtoRunStepResponse::status_code() const {
948 return response_->status_code();
949 }
950
status_error_message() const951 const string& NonOwnedProtoRunStepResponse::status_error_message() const {
952 return response_->status_error_message();
953 }
954
set_status(const Status & status)955 void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
956 response_->set_status_code(status.code());
957 response_->set_status_error_message(status.error_message());
958 }
959
get_proto()960 RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
961
962 } // namespace tensorflow
963