/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"

namespace tensorflow {

namespace {

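// Attempts to parse `tensor_proto` into `*out_tensor` using the CPU
// allocator. Returns false if the proto's dtype is out of range or the
// tensor contents cannot be parsed.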
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
                              Tensor* out_tensor) {
  if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
    Tensor parsed(tensor_proto.dtype());
    if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
      *out_tensor = parsed;
      return true;
    }
  }
  return false;
}

}  // namespace

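// InMemoryRunStepRequest keeps feed values as in-process Tensors and only
// serializes them into a RunStepRequest proto when ToProto() is called.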
const string& InMemoryRunStepRequest::session_handle() const {
  return session_handle_;
}

void InMemoryRunStepRequest::set_session_handle(const string& handle) {
  session_handle_ = handle;
}

const string& InMemoryRunStepRequest::partial_run_handle() const {
  return partial_run_handle_;
}

void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
  partial_run_handle_ = handle;
}

size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
const string& InMemoryRunStepRequest::feed_name(size_t i) const {
  return feeds_[i].first;
}

Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
  *out_tensor = feeds_[i].second;
  return Status::OK();
}

Status InMemoryRunStepRequest::FeedValue(size_t i,
                                         TensorProto* out_tensor) const {
  feeds_[i].second.AsProtoTensorContent(out_tensor);
  return Status::OK();
}

void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
  feeds_.emplace_back(name, value);
}

size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
  return fetches_[i];
}
void InMemoryRunStepRequest::add_fetch(const string& name) {
  fetches_.push_back(name);
}

size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
const string& InMemoryRunStepRequest::target_name(size_t i) const {
  return targets_[i];
}
void InMemoryRunStepRequest::add_target(const string& name) {
  targets_.push_back(name);
}

const RunOptions& InMemoryRunStepRequest::options() const { return options_; }

RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }

bool InMemoryRunStepRequest::store_errors_in_response_body() const {
  return store_errors_in_response_body_;
}

int64 InMemoryRunStepRequest::request_id() const {
  return 0;  // no need to track request id for local version.
}

void InMemoryRunStepRequest::set_store_errors_in_response_body(
    bool store_errors) {
  store_errors_in_response_body_ = store_errors;
}

string InMemoryRunStepRequest::DebugString() const {
  return ToProto().DebugString();
}

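// Lazily builds the RunStepRequest proto from the in-memory fields the first
// time it is requested, and caches it for subsequent calls.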
const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
  if (!proto_version_) {
    proto_version_.reset(new RunStepRequest);
    proto_version_->set_session_handle(session_handle());
    proto_version_->set_partial_run_handle(partial_run_handle());
    for (size_t i = 0; i < num_feeds(); ++i) {
      auto feed = proto_version_->add_feed();
      feed->set_name(feed_name(i));
      feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
    }
    for (size_t i = 0; i < num_fetches(); ++i) {
      proto_version_->add_fetch(fetch_name(i));
    }
    for (size_t i = 0; i < num_targets(); ++i) {
      proto_version_->add_target(target_name(i));
    }
    *proto_version_->mutable_options() = options();
  }
  return *proto_version_;
}

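// MutableProtoRunStepRequest forwards every accessor and mutator directly to
// an owned RunStepRequest proto.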
const string& MutableProtoRunStepRequest::session_handle() const {
  return request_.session_handle();
}
void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
  request_.set_session_handle(handle);
}

const string& MutableProtoRunStepRequest::partial_run_handle() const {
  return request_.partial_run_handle();
}
void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
  request_.set_partial_run_handle(handle);
}

size_t MutableProtoRunStepRequest::num_feeds() const {
  return request_.feed_size();
}
const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
  return request_.feed(i).name();
}
Status MutableProtoRunStepRequest::FeedValue(size_t i,
                                             Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return Status::OK();
  }
}

Status MutableProtoRunStepRequest::FeedValue(size_t i,
                                             TensorProto* out_tensor) const {
  *out_tensor = request_.feed(i).tensor();
  return Status::OK();
}

void MutableProtoRunStepRequest::add_feed(const string& name,
                                          const Tensor& value) {
  NamedTensorProto* feed = request_.add_feed();
  feed->set_name(name);
  TensorProto* value_proto = feed->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

size_t MutableProtoRunStepRequest::num_fetches() const {
  return request_.fetch_size();
}

const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
  return request_.fetch(i);
}
void MutableProtoRunStepRequest::add_fetch(const string& name) {
  request_.add_fetch(name);
}

size_t MutableProtoRunStepRequest::num_targets() const {
  return request_.target_size();
}

const string& MutableProtoRunStepRequest::target_name(size_t i) const {
  return request_.target(i);
}

void MutableProtoRunStepRequest::add_target(const string& name) {
  request_.add_target(name);
}

const RunOptions& MutableProtoRunStepRequest::options() const {
  return request_.options();
}

RunOptions* MutableProtoRunStepRequest::mutable_options() {
  return request_.mutable_options();
}

bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
  return request_.store_errors_in_response_body();
}

void MutableProtoRunStepRequest::set_store_errors_in_response_body(
    bool store_errors) {
  request_.set_store_errors_in_response_body(store_errors);
}

int64 MutableProtoRunStepRequest::request_id() const {
  return request_.request_id();
}

string MutableProtoRunStepRequest::DebugString() const {
  return request_.DebugString();
}

const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
  return request_;
}

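// ProtoRunStepRequest is a read-only view over a RunStepRequest proto that it
// does not own.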
ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
    : request_(request) {}

const string& ProtoRunStepRequest::session_handle() const {
  return request_->session_handle();
}

const string& ProtoRunStepRequest::partial_run_handle() const {
  return request_->partial_run_handle();
}

size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }

const string& ProtoRunStepRequest::feed_name(size_t i) const {
  return request_->feed(i).name();
}

Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return Status::OK();
  }
}

Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
  *out_tensor = request_->feed(i).tensor();
  return Status::OK();
}

size_t ProtoRunStepRequest::num_fetches() const {
  return request_->fetch_size();
}

const string& ProtoRunStepRequest::fetch_name(size_t i) const {
  return request_->fetch(i);
}

size_t ProtoRunStepRequest::num_targets() const {
  return request_->target_size();
}

const string& ProtoRunStepRequest::target_name(size_t i) const {
  return request_->target(i);
}

const RunOptions& ProtoRunStepRequest::options() const {
  return request_->options();
}

bool ProtoRunStepRequest::store_errors_in_response_body() const {
  return request_->store_errors_in_response_body();
}

int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }

string ProtoRunStepRequest::DebugString() const {
  return request_->DebugString();
}

const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }

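// InMemoryRunGraphRequest keeps send values as in-process Tensors; a cached
// RunGraphRequest proto is only built on demand in ToProto().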
const string& InMemoryRunGraphRequest::session_handle() const {
  return session_handle_;
}

bool InMemoryRunGraphRequest::create_worker_session_called() const {
  return create_worker_session_called_;
}

void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
  session_handle_ = handle;
}

void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
  create_worker_session_called_ = called;
}

const string& InMemoryRunGraphRequest::graph_handle() const {
  return graph_handle_;
}

void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
  graph_handle_ = handle;
}

int64 InMemoryRunGraphRequest::step_id() const { return step_id_; }

void InMemoryRunGraphRequest::set_step_id(int64 step_id) { step_id_ = step_id; }

const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
  return exec_opts_;
}

ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
  return &exec_opts_;
}

size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }

const string& InMemoryRunGraphRequest::send_key(size_t i) const {
  return sends_[i].first;
}

Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
  *out_tensor = sends_[i].second;
  return Status::OK();
}

Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
    const RunStepRequestWrapper& run_step_request, size_t i,
    const string& send_key) {
  Tensor tensor;
  TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
  sends_.emplace_back(send_key, std::move(tensor));
  return Status::OK();
}

// TODO(b/74355905): Add a specialized implementation that avoids
// copying the tensor when at least two of the {client, master,
// worker} are in the same process.
Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
    const RunCallableRequest& run_callable_request, size_t i,
    const string& send_key) {
  Tensor tensor;
  if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  }
  sends_.emplace_back(send_key, std::move(tensor));
  return Status::OK();
}

size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }

const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
  return recvs_[i];
}

void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
  recvs_.push_back(recv_key);
}

bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }

void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
  is_partial_ = is_partial;
}

bool InMemoryRunGraphRequest::is_last_partial_run() const {
  return is_last_partial_run_;
}

void InMemoryRunGraphRequest::set_is_last_partial_run(
    bool is_last_partial_run) {
  is_last_partial_run_ = is_last_partial_run;
}

bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
  return store_errors_in_response_body_;
}

void InMemoryRunGraphRequest::set_store_errors_in_response_body(
    bool store_errors) {
  store_errors_in_response_body_ = store_errors;
}

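// Lazily materializes the cached RunGraphRequest proto from the in-memory
// fields, serializing each send tensor the first time ToProto() is called.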
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
  if (!proto_version_) {
    proto_version_.reset(new RunGraphRequest);
    proto_version_->set_session_handle(session_handle());
    proto_version_->set_create_worker_session_called(
        create_worker_session_called());
    proto_version_->set_graph_handle(graph_handle());
    proto_version_->set_step_id(step_id());
    *proto_version_->mutable_exec_opts() = exec_opts();
    for (size_t i = 0; i < num_sends(); ++i) {
      auto send = proto_version_->add_send();
      send->set_name(send_key(i));
      sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
    }
    for (size_t i = 0; i < num_recvs(); ++i) {
      proto_version_->add_recv_key(recv_key(i));
    }
    proto_version_->set_is_partial(is_partial());
    proto_version_->set_is_last_partial_run(is_last_partial_run());
  }
  return *proto_version_;
}

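// MutableProtoRunGraphRequest forwards every accessor and mutator directly to
// an owned RunGraphRequest proto.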
const string& MutableProtoRunGraphRequest::session_handle() const {
  return request_.session_handle();
}

void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
  request_.set_session_handle(handle);
}

bool MutableProtoRunGraphRequest::create_worker_session_called() const {
  return request_.create_worker_session_called();
}

void MutableProtoRunGraphRequest::set_create_worker_session_called(
    bool called) {
  request_.set_create_worker_session_called(called);
}

const string& MutableProtoRunGraphRequest::graph_handle() const {
  return request_.graph_handle();
}

void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
  request_.set_graph_handle(handle);
}

int64 MutableProtoRunGraphRequest::step_id() const {
  return request_.step_id();
}

void MutableProtoRunGraphRequest::set_step_id(int64 step_id) {
  request_.set_step_id(step_id);
}

const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
  return request_.exec_opts();
}

ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
  return request_.mutable_exec_opts();
}

size_t MutableProtoRunGraphRequest::num_sends() const {
  return request_.send_size();
}

const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
  return request_.send(i).name();
}

Status MutableProtoRunGraphRequest::SendValue(size_t i,
                                              Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return Status::OK();
  }
}

Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
    const RunStepRequestWrapper& run_step_request, size_t i,
    const string& send_key) {
  NamedTensorProto* send = request_.add_send();
  send->set_name(send_key);
  TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
  return Status::OK();
}

// TODO(b/74355905): Add a specialized implementation that avoids
// copying the tensor when at least two of the {client, master,
// worker} are in the same process.
Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
    const RunCallableRequest& run_callable_request, size_t i,
    const string& send_key) {
  NamedTensorProto* send = request_.add_send();
  send->set_name(send_key);
  *send->mutable_tensor() = run_callable_request.feed(i);
  return Status::OK();
}

size_t MutableProtoRunGraphRequest::num_recvs() const {
  return request_.recv_key_size();
}

const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
  return request_.recv_key(i);
}

void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
  request_.add_recv_key(recv_key);
}

bool MutableProtoRunGraphRequest::is_partial() const {
  return request_.is_partial();
}

void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
  request_.set_is_partial(is_partial);
}

bool MutableProtoRunGraphRequest::is_last_partial_run() const {
  return request_.is_last_partial_run();
}

void MutableProtoRunGraphRequest::set_is_last_partial_run(
    bool is_last_partial_run) {
  request_.set_is_last_partial_run(is_last_partial_run);
}

bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
  return request_.store_errors_in_response_body();
}

void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
    bool store_errors) {
  request_.set_store_errors_in_response_body(store_errors);
}

const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
  return request_;
}

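// ProtoRunGraphRequest is a read-only view over a RunGraphRequest proto that
// it does not own.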
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
    : request_(request) {}

const string& ProtoRunGraphRequest::session_handle() const {
  return request_->session_handle();
}

bool ProtoRunGraphRequest::create_worker_session_called() const {
  return request_->create_worker_session_called();
}

const string& ProtoRunGraphRequest::graph_handle() const {
  return request_->graph_handle();
}

int64 ProtoRunGraphRequest::step_id() const { return request_->step_id(); }

const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
  return request_->exec_opts();
}

size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }

const string& ProtoRunGraphRequest::send_key(size_t i) const {
  return request_->send(i).name();
}

Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return Status::OK();
  }
}

size_t ProtoRunGraphRequest::num_recvs() const {
  return request_->recv_key_size();
}

const string& ProtoRunGraphRequest::recv_key(size_t i) const {
  return request_->recv_key(i);
}

bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }

bool ProtoRunGraphRequest::is_last_partial_run() const {
  return request_->is_last_partial_run();
}

bool ProtoRunGraphRequest::store_errors_in_response_body() const {
  return request_->store_errors_in_response_body();
}

const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
  return *request_;
}

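// InMemoryRunGraphResponse keeps recv tensors, step stats, the cost graph,
// and partition graphs in process; it cannot expose a mutable proto.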
size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }

const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
  return recvs_[i].first;
}

Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
  recvs_[i].second.AsProtoTensorContent(out_tensor);
  return Status::OK();
}

Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
  *out_tensor = recvs_[i].second;
  return Status::OK();
}

void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
  recvs_.emplace_back(key, value);
}

StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
  return &step_stats_;
}

CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
  return &cost_graph_;
}

errors::Code InMemoryRunGraphResponse::status_code() const {
  return status_.code();
}

const string& InMemoryRunGraphResponse::status_error_message() const {
  return status_.error_message();
}

void InMemoryRunGraphResponse::set_status(const Status& status) {
  status_ = status;
}

RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
  LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
  return nullptr;
}

size_t InMemoryRunGraphResponse::num_partition_graphs() const {
  return partition_graphs_.size();
}

GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
  return &partition_graphs_[i];
}

void InMemoryRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  partition_graphs_.push_back(partition_graph);
}

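// OwnedProtoRunGraphResponse stores its data directly in an owned
// RunGraphResponse proto; recv tensors are parsed out of the proto on demand.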
size_t OwnedProtoRunGraphResponse::num_recvs() const {
  return response_.recv_size();
}

const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
  return response_.recv(i).name();
}

Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
                                             TensorProto* out_tensor) {
  out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
  return Status::OK();
}

Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
  if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
  } else {
    return Status::OK();
  }
}

void OwnedProtoRunGraphResponse::AddRecv(const string& key,
                                         const Tensor& value) {
  NamedTensorProto* recv = response_.add_recv();
  recv->set_name(key);
  TensorProto* value_proto = recv->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
  return response_.mutable_step_stats();
}

CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
  return response_.mutable_cost_graph();
}

errors::Code OwnedProtoRunGraphResponse::status_code() const {
  return response_.status_code();
}

const string& OwnedProtoRunGraphResponse::status_error_message() const {
  return response_.status_error_message();
}

void OwnedProtoRunGraphResponse::set_status(const Status& status) {
  response_.set_status_code(status.code());
  response_.set_status_error_message(status.error_message());
}

RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }

size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
  return response_.partition_graph_size();
}

GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
  return response_.mutable_partition_graph(i);
}

void OwnedProtoRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  GraphDef* graph_def = response_.mutable_partition_graph()->Add();
  *graph_def = partition_graph;
}

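// NonOwnedProtoRunGraphResponse behaves like the owned variant but writes
// through to a RunGraphResponse proto that it does not own.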
NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
    RunGraphResponse* response)
    : response_(response) {}

size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
  return response_->recv_size();
}

const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
  return response_->recv(i).name();
}

Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
                                                TensorProto* out_tensor) {
  out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
  return Status::OK();
}

Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
  if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
  } else {
    return Status::OK();
  }
}

void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
                                            const Tensor& value) {
  NamedTensorProto* recv = response_->add_recv();
  recv->set_name(key);
  TensorProto* value_proto = recv->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
  return response_->mutable_step_stats();
}

CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
  return response_->mutable_cost_graph();
}

errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
  return response_->status_code();
}

const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
  return response_->status_error_message();
}

void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
  response_->set_status_code(status.code());
  response_->set_status_error_message(status.error_message());
}

RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
  return response_;
}

size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
  return response_->partition_graph_size();
}

GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
  return response_->mutable_partition_graph(i);
}

void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  GraphDef* graph_def = response_->add_partition_graph();
  *graph_def = partition_graph;
}

MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}

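// InMemoryRunStepResponse keeps fetched tensors and run metadata in process;
// like the other in-memory wrappers, it cannot expose a mutable proto.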
size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }

const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
  return tensors_[i].first;
}

Status InMemoryRunStepResponse::TensorValue(size_t i,
                                            Tensor* out_tensor) const {
  *out_tensor = tensors_[i].second;
  return Status::OK();
}

const RunMetadata& InMemoryRunStepResponse::metadata() const {
  return metadata_;
}

Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
  Tensor tensor;
  TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
  tensors_.emplace_back(name, tensor);
  return Status::OK();
}

RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }

errors::Code InMemoryRunStepResponse::status_code() const {
  return status_.code();
}

const string& InMemoryRunStepResponse::status_error_message() const {
  return status_.error_message();
}

void InMemoryRunStepResponse::set_status(const Status& status) {
  status_ = status;
}

RunStepResponse* InMemoryRunStepResponse::get_proto() {
  LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
  return nullptr;
}

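// OwnedProtoRunStepResponse stores fetched tensors and metadata in an owned
// RunStepResponse proto.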
size_t OwnedProtoRunStepResponse::num_tensors() const {
  return response_.tensor_size();
}

const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
  return response_.tensor(i).name();
}

Status OwnedProtoRunStepResponse::TensorValue(size_t i,
                                              Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
  } else {
    return Status::OK();
  }
}

const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
  return response_.metadata();
}

Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    size_t i) {
  NamedTensorProto* response_tensor = response_.add_tensor();
  response_tensor->set_name(name);
  return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
}

RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
  return response_.mutable_metadata();
}

errors::Code OwnedProtoRunStepResponse::status_code() const {
  return response_.status_code();
}

const string& OwnedProtoRunStepResponse::status_error_message() const {
  return response_.status_error_message();
}

void OwnedProtoRunStepResponse::set_status(const Status& status) {
  response_.set_status_code(status.code());
  response_.set_status_error_message(status.error_message());
}

RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }

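// NonOwnedProtoRunStepResponse behaves like the owned variant but writes
// through to a RunStepResponse proto that it does not own.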
NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
    RunStepResponse* response)
    : response_(response) {}

size_t NonOwnedProtoRunStepResponse::num_tensors() const {
  return response_->tensor_size();
}

const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
  return response_->tensor(i).name();
}

Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
                                                 Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
  } else {
    return Status::OK();
  }
}

const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
  return response_->metadata();
}

Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    size_t i) {
  NamedTensorProto* response_tensor = response_->add_tensor();
  response_tensor->set_name(name);
  return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
}

RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
  return response_->mutable_metadata();
}

errors::Code NonOwnedProtoRunStepResponse::status_code() const {
  return response_->status_code();
}

const string& NonOwnedProtoRunStepResponse::status_error_message() const {
  return response_->status_error_message();
}

void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
  response_->set_status_code(status.code());
  response_->set_status_error_message(status.error_message());
}

RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }

}  // namespace tensorflow