1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
17
18 #include "tensorflow/core/framework/cost_graph.pb.h"
19 #include "tensorflow/core/framework/step_stats.pb.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/protobuf/config.pb.h"
22 #include "tensorflow/core/protobuf/named_tensor.pb.h"
23
24 namespace tensorflow {
25
ParseTensorProtoToTensor(const TensorProto & tensor_proto,Tensor * out_tensor)26 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
27 Tensor* out_tensor) {
28 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
29 Tensor parsed(tensor_proto.dtype());
30 if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
31 *out_tensor = parsed;
32 return true;
33 }
34 }
35 return false;
36 }
37
session_handle() const38 const string& InMemoryRunStepRequest::session_handle() const {
39 return session_handle_;
40 }
41
set_session_handle(const string & handle)42 void InMemoryRunStepRequest::set_session_handle(const string& handle) {
43 session_handle_ = handle;
44 }
45
partial_run_handle() const46 const string& InMemoryRunStepRequest::partial_run_handle() const {
47 return partial_run_handle_;
48 }
49
set_partial_run_handle(const string & handle)50 void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
51 partial_run_handle_ = handle;
52 }
53
num_feeds() const54 size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
feed_name(size_t i) const55 const string& InMemoryRunStepRequest::feed_name(size_t i) const {
56 return feeds_[i].first;
57 }
58
FeedValue(size_t i,Tensor * out_tensor) const59 Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
60 *out_tensor = feeds_[i].second;
61 return Status::OK();
62 }
63
FeedValue(size_t i,TensorProto * out_tensor) const64 Status InMemoryRunStepRequest::FeedValue(size_t i,
65 TensorProto* out_tensor) const {
66 feeds_[i].second.AsProtoTensorContent(out_tensor);
67 return Status::OK();
68 }
69
add_feed(const string & name,const Tensor & value)70 void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
71 feeds_.emplace_back(name, value);
72 }
73
num_fetches() const74 size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
fetch_name(size_t i) const75 const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
76 return fetches_[i];
77 }
add_fetch(const string & name)78 void InMemoryRunStepRequest::add_fetch(const string& name) {
79 fetches_.push_back(name);
80 }
81
num_targets() const82 size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
target_name(size_t i) const83 const string& InMemoryRunStepRequest::target_name(size_t i) const {
84 return targets_[i];
85 }
add_target(const string & name)86 void InMemoryRunStepRequest::add_target(const string& name) {
87 targets_.push_back(name);
88 }
89
options() const90 const RunOptions& InMemoryRunStepRequest::options() const { return options_; }
91
mutable_options()92 RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }
93
store_errors_in_response_body() const94 bool InMemoryRunStepRequest::store_errors_in_response_body() const {
95 return store_errors_in_response_body_;
96 }
97
request_id() const98 int64 InMemoryRunStepRequest::request_id() const {
99 return 0; // no need to track request id for local version.
100 }
101
set_store_errors_in_response_body(bool store_errors)102 void InMemoryRunStepRequest::set_store_errors_in_response_body(
103 bool store_errors) {
104 store_errors_in_response_body_ = store_errors;
105 }
106
DebugString() const107 string InMemoryRunStepRequest::DebugString() const {
108 return ToProto().DebugString();
109 }
110
ToProto() const111 const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
112 if (!proto_version_) {
113 proto_version_.reset(new RunStepRequest);
114 proto_version_->set_session_handle(session_handle());
115 proto_version_->set_partial_run_handle(partial_run_handle());
116 for (size_t i = 0; i < num_feeds(); ++i) {
117 auto feed = proto_version_->add_feed();
118 feed->set_name(feed_name(i));
119 feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
120 }
121 for (size_t i = 0; i < num_fetches(); ++i) {
122 proto_version_->add_fetch(fetch_name(i));
123 }
124 for (size_t i = 0; i < num_targets(); ++i) {
125 proto_version_->add_target(target_name(i));
126 }
127 *proto_version_->mutable_options() = options();
128 }
129 return *proto_version_;
130 }
131
session_handle() const132 const string& MutableProtoRunStepRequest::session_handle() const {
133 return request_.session_handle();
134 }
set_session_handle(const string & handle)135 void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
136 request_.set_session_handle(handle);
137 }
138
partial_run_handle() const139 const string& MutableProtoRunStepRequest::partial_run_handle() const {
140 return request_.partial_run_handle();
141 }
set_partial_run_handle(const string & handle)142 void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
143 request_.set_partial_run_handle(handle);
144 }
145
num_feeds() const146 size_t MutableProtoRunStepRequest::num_feeds() const {
147 return request_.feed_size();
148 }
feed_name(size_t i) const149 const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
150 return request_.feed(i).name();
151 }
FeedValue(size_t i,Tensor * out_tensor) const152 Status MutableProtoRunStepRequest::FeedValue(size_t i,
153 Tensor* out_tensor) const {
154 if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
155 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
156 } else {
157 return Status::OK();
158 }
159 }
160
FeedValue(size_t i,TensorProto * out_tensor) const161 Status MutableProtoRunStepRequest::FeedValue(size_t i,
162 TensorProto* out_tensor) const {
163 *out_tensor = request_.feed(i).tensor();
164 return Status::OK();
165 }
166
add_feed(const string & name,const Tensor & value)167 void MutableProtoRunStepRequest::add_feed(const string& name,
168 const Tensor& value) {
169 NamedTensorProto* feed = request_.add_feed();
170 feed->set_name(name);
171 TensorProto* value_proto = feed->mutable_tensor();
172 value.AsProtoTensorContent(value_proto);
173 }
174
num_fetches() const175 size_t MutableProtoRunStepRequest::num_fetches() const {
176 return request_.fetch_size();
177 }
178
fetch_name(size_t i) const179 const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
180 return request_.fetch(i);
181 }
add_fetch(const string & name)182 void MutableProtoRunStepRequest::add_fetch(const string& name) {
183 request_.add_fetch(name);
184 }
185
num_targets() const186 size_t MutableProtoRunStepRequest::num_targets() const {
187 return request_.target_size();
188 }
189
target_name(size_t i) const190 const string& MutableProtoRunStepRequest::target_name(size_t i) const {
191 return request_.target(i);
192 }
193
add_target(const string & name)194 void MutableProtoRunStepRequest::add_target(const string& name) {
195 request_.add_target(name);
196 }
197
options() const198 const RunOptions& MutableProtoRunStepRequest::options() const {
199 return request_.options();
200 }
201
mutable_options()202 RunOptions* MutableProtoRunStepRequest::mutable_options() {
203 return request_.mutable_options();
204 }
205
store_errors_in_response_body() const206 bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
207 return request_.store_errors_in_response_body();
208 }
209
set_store_errors_in_response_body(bool store_errors)210 void MutableProtoRunStepRequest::set_store_errors_in_response_body(
211 bool store_errors) {
212 request_.set_store_errors_in_response_body(store_errors);
213 }
214
request_id() const215 int64 MutableProtoRunStepRequest::request_id() const {
216 return request_.request_id();
217 }
218
DebugString() const219 string MutableProtoRunStepRequest::DebugString() const {
220 return request_.DebugString();
221 }
222
ToProto() const223 const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
224 return request_;
225 }
226
ProtoRunStepRequest(const RunStepRequest * request)227 ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
228 : request_(request) {}
229
session_handle() const230 const string& ProtoRunStepRequest::session_handle() const {
231 return request_->session_handle();
232 }
233
partial_run_handle() const234 const string& ProtoRunStepRequest::partial_run_handle() const {
235 return request_->partial_run_handle();
236 }
237
num_feeds() const238 size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }
239
feed_name(size_t i) const240 const string& ProtoRunStepRequest::feed_name(size_t i) const {
241 return request_->feed(i).name();
242 }
243
FeedValue(size_t i,Tensor * out_tensor) const244 Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
245 if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
246 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
247 } else {
248 return Status::OK();
249 }
250 }
251
FeedValue(size_t i,TensorProto * out_tensor) const252 Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
253 *out_tensor = request_->feed(i).tensor();
254 return Status::OK();
255 }
256
num_fetches() const257 size_t ProtoRunStepRequest::num_fetches() const {
258 return request_->fetch_size();
259 }
260
fetch_name(size_t i) const261 const string& ProtoRunStepRequest::fetch_name(size_t i) const {
262 return request_->fetch(i);
263 }
264
num_targets() const265 size_t ProtoRunStepRequest::num_targets() const {
266 return request_->target_size();
267 }
268
target_name(size_t i) const269 const string& ProtoRunStepRequest::target_name(size_t i) const {
270 return request_->target(i);
271 }
272
options() const273 const RunOptions& ProtoRunStepRequest::options() const {
274 return request_->options();
275 }
276
store_errors_in_response_body() const277 bool ProtoRunStepRequest::store_errors_in_response_body() const {
278 return request_->store_errors_in_response_body();
279 }
280
request_id() const281 int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }
282
DebugString() const283 string ProtoRunStepRequest::DebugString() const {
284 return request_->DebugString();
285 }
286
ToProto() const287 const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
288
session_handle() const289 const string& InMemoryRunGraphRequest::session_handle() const {
290 return session_handle_;
291 }
292
create_worker_session_called() const293 bool InMemoryRunGraphRequest::create_worker_session_called() const {
294 return create_worker_session_called_;
295 }
296
set_session_handle(const string & handle)297 void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
298 session_handle_ = handle;
299 }
300
set_create_worker_session_called(bool called)301 void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
302 create_worker_session_called_ = called;
303 }
304
graph_handle() const305 const string& InMemoryRunGraphRequest::graph_handle() const {
306 return graph_handle_;
307 }
308
set_graph_handle(const string & handle)309 void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
310 graph_handle_ = handle;
311 }
312
step_id() const313 int64 InMemoryRunGraphRequest::step_id() const { return step_id_; }
314
set_step_id(int64_t step_id)315 void InMemoryRunGraphRequest::set_step_id(int64_t step_id) {
316 step_id_ = step_id;
317 }
318
exec_opts() const319 const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
320 return exec_opts_;
321 }
322
mutable_exec_opts()323 ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
324 return &exec_opts_;
325 }
326
num_sends() const327 size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }
328
send_key(size_t i) const329 const string& InMemoryRunGraphRequest::send_key(size_t i) const {
330 return sends_[i].first;
331 }
332
SendValue(size_t i,Tensor * out_tensor) const333 Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
334 *out_tensor = sends_[i].second;
335 return Status::OK();
336 }
337
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)338 Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
339 const RunStepRequestWrapper& run_step_request, size_t i,
340 const string& send_key) {
341 Tensor tensor;
342 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
343 sends_.emplace_back(send_key, std::move(tensor));
344 return Status::OK();
345 }
346
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)347 Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
348 const RunCallableRequest& run_callable_request, size_t i,
349 const string& send_key) {
350 Tensor tensor;
351 if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
352 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
353 }
354 sends_.emplace_back(send_key, std::move(tensor));
355 return Status::OK();
356 }
357
num_recvs() const358 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
359
recv_key(size_t i) const360 const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
361 return recvs_[i];
362 }
363
add_recv_key(const string & recv_key)364 void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
365 recvs_.push_back(recv_key);
366 }
367
is_partial() const368 bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }
369
set_is_partial(bool is_partial)370 void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
371 is_partial_ = is_partial;
372 }
373
is_last_partial_run() const374 bool InMemoryRunGraphRequest::is_last_partial_run() const {
375 return is_last_partial_run_;
376 }
377
set_is_last_partial_run(bool is_last_partial_run)378 void InMemoryRunGraphRequest::set_is_last_partial_run(
379 bool is_last_partial_run) {
380 is_last_partial_run_ = is_last_partial_run;
381 }
382
store_errors_in_response_body() const383 bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
384 return store_errors_in_response_body_;
385 }
386
set_store_errors_in_response_body(bool store_errors)387 void InMemoryRunGraphRequest::set_store_errors_in_response_body(
388 bool store_errors) {
389 store_errors_in_response_body_ = store_errors;
390 }
391
request_id() const392 int64 InMemoryRunGraphRequest::request_id() const { return request_id_; }
393
set_request_id(int64_t request_id)394 void InMemoryRunGraphRequest::set_request_id(int64_t request_id) {
395 request_id_ = request_id;
396 }
397
ToProto() const398 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
399 if (!proto_version_) {
400 proto_version_.reset(new RunGraphRequest);
401 proto_version_->set_session_handle(session_handle());
402 proto_version_->set_create_worker_session_called(
403 create_worker_session_called());
404 proto_version_->set_graph_handle(graph_handle());
405 proto_version_->set_step_id(step_id());
406 *proto_version_->mutable_exec_opts() = exec_opts();
407 for (size_t i = 0; i < num_sends(); ++i) {
408 auto send = proto_version_->add_send();
409 send->set_name(send_key(i));
410 sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
411 }
412 for (size_t i = 0; i < num_recvs(); ++i) {
413 proto_version_->add_recv_key(recv_key(i));
414 }
415 proto_version_->set_is_partial(is_partial());
416 proto_version_->set_is_last_partial_run(is_last_partial_run());
417 }
418 proto_version_->set_store_errors_in_response_body(
419 store_errors_in_response_body_);
420 proto_version_->set_request_id(request_id_);
421 return *proto_version_;
422 }
423
session_handle() const424 const string& MutableProtoRunGraphRequest::session_handle() const {
425 return request_.session_handle();
426 }
427
set_session_handle(const string & handle)428 void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
429 request_.set_session_handle(handle);
430 }
431
create_worker_session_called() const432 bool MutableProtoRunGraphRequest::create_worker_session_called() const {
433 return request_.create_worker_session_called();
434 }
435
set_create_worker_session_called(bool called)436 void MutableProtoRunGraphRequest::set_create_worker_session_called(
437 bool called) {
438 request_.set_create_worker_session_called(called);
439 }
440
graph_handle() const441 const string& MutableProtoRunGraphRequest::graph_handle() const {
442 return request_.graph_handle();
443 }
444
set_graph_handle(const string & handle)445 void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
446 request_.set_graph_handle(handle);
447 }
448
step_id() const449 int64 MutableProtoRunGraphRequest::step_id() const {
450 return request_.step_id();
451 }
452
set_step_id(int64_t step_id)453 void MutableProtoRunGraphRequest::set_step_id(int64_t step_id) {
454 request_.set_step_id(step_id);
455 }
456
exec_opts() const457 const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
458 return request_.exec_opts();
459 }
460
mutable_exec_opts()461 ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
462 return request_.mutable_exec_opts();
463 }
464
num_sends() const465 size_t MutableProtoRunGraphRequest::num_sends() const {
466 return request_.send_size();
467 }
468
send_key(size_t i) const469 const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
470 return request_.send(i).name();
471 }
472
SendValue(size_t i,Tensor * out_tensor) const473 Status MutableProtoRunGraphRequest::SendValue(size_t i,
474 Tensor* out_tensor) const {
475 if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
476 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
477 } else {
478 return Status::OK();
479 }
480 }
481
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)482 Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
483 const RunStepRequestWrapper& run_step_request, size_t i,
484 const string& send_key) {
485 NamedTensorProto* send = request_.add_send();
486 send->set_name(send_key);
487 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
488 return Status::OK();
489 }
490
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)491 Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
492 const RunCallableRequest& run_callable_request, size_t i,
493 const string& send_key) {
494 NamedTensorProto* send = request_.add_send();
495 send->set_name(send_key);
496 *send->mutable_tensor() = run_callable_request.feed(i);
497 return Status::OK();
498 }
499
num_recvs() const500 size_t MutableProtoRunGraphRequest::num_recvs() const {
501 return request_.recv_key_size();
502 }
503
recv_key(size_t i) const504 const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
505 return request_.recv_key(i);
506 }
507
add_recv_key(const string & recv_key)508 void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
509 request_.add_recv_key(recv_key);
510 }
511
is_partial() const512 bool MutableProtoRunGraphRequest::is_partial() const {
513 return request_.is_partial();
514 }
515
set_is_partial(bool is_partial)516 void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
517 request_.set_is_partial(is_partial);
518 }
519
is_last_partial_run() const520 bool MutableProtoRunGraphRequest::is_last_partial_run() const {
521 return request_.is_last_partial_run();
522 }
523
set_is_last_partial_run(bool is_last_partial_run)524 void MutableProtoRunGraphRequest::set_is_last_partial_run(
525 bool is_last_partial_run) {
526 request_.set_is_last_partial_run(is_last_partial_run);
527 }
528
store_errors_in_response_body() const529 bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
530 return request_.store_errors_in_response_body();
531 }
532
set_store_errors_in_response_body(bool store_errors)533 void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
534 bool store_errors) {
535 request_.set_store_errors_in_response_body(store_errors);
536 }
537
request_id() const538 int64 MutableProtoRunGraphRequest::request_id() const {
539 return request_.request_id();
540 }
541
set_request_id(int64_t request_id)542 void MutableProtoRunGraphRequest::set_request_id(int64_t request_id) {
543 request_.set_request_id(request_id);
544 }
545
ToProto() const546 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
547 return request_;
548 }
549
ProtoRunGraphRequest(const RunGraphRequest * request)550 ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
551 : request_(request) {}
552
session_handle() const553 const string& ProtoRunGraphRequest::session_handle() const {
554 return request_->session_handle();
555 }
556
create_worker_session_called() const557 bool ProtoRunGraphRequest::create_worker_session_called() const {
558 return request_->create_worker_session_called();
559 }
560
graph_handle() const561 const string& ProtoRunGraphRequest::graph_handle() const {
562 return request_->graph_handle();
563 }
564
step_id() const565 int64 ProtoRunGraphRequest::step_id() const { return request_->step_id(); }
566
exec_opts() const567 const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
568 return request_->exec_opts();
569 }
570
num_sends() const571 size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }
572
send_key(size_t i) const573 const string& ProtoRunGraphRequest::send_key(size_t i) const {
574 return request_->send(i).name();
575 }
576
SendValue(size_t i,Tensor * out_tensor) const577 Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
578 if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
579 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
580 } else {
581 return Status::OK();
582 }
583 }
584
num_recvs() const585 size_t ProtoRunGraphRequest::num_recvs() const {
586 return request_->recv_key_size();
587 }
588
recv_key(size_t i) const589 const string& ProtoRunGraphRequest::recv_key(size_t i) const {
590 return request_->recv_key(i);
591 }
592
is_partial() const593 bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }
594
is_last_partial_run() const595 bool ProtoRunGraphRequest::is_last_partial_run() const {
596 return request_->is_last_partial_run();
597 }
598
store_errors_in_response_body() const599 bool ProtoRunGraphRequest::store_errors_in_response_body() const {
600 return request_->store_errors_in_response_body();
601 }
602
request_id() const603 int64 ProtoRunGraphRequest::request_id() const {
604 return request_->request_id();
605 }
606
ToProto() const607 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
608 return *request_;
609 }
610
num_recvs() const611 size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }
612
recv_key(size_t i) const613 const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
614 return recvs_[i].first;
615 }
616
RecvValue(size_t i,TensorProto * out_tensor)617 Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
618 recvs_[i].second.AsProtoTensorContent(out_tensor);
619 return Status::OK();
620 }
621
RecvValue(size_t i,Tensor * out_tensor)622 Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
623 *out_tensor = recvs_[i].second;
624 return Status::OK();
625 }
626
AddRecv(const string & key,const Tensor & value)627 void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
628 recvs_.emplace_back(key, value);
629 }
630
mutable_step_stats()631 StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
632 return &step_stats_;
633 }
634
mutable_cost_graph()635 CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
636 return &cost_graph_;
637 }
638
status_code() const639 errors::Code InMemoryRunGraphResponse::status_code() const {
640 return status_.code();
641 }
642
status_error_message() const643 const string& InMemoryRunGraphResponse::status_error_message() const {
644 return status_.error_message();
645 }
646
set_status(const Status & status)647 void InMemoryRunGraphResponse::set_status(const Status& status) {
648 status_ = status;
649 }
650
get_proto()651 RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
652 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
653 return nullptr;
654 }
655
num_partition_graphs() const656 size_t InMemoryRunGraphResponse::num_partition_graphs() const {
657 return partition_graphs_.size();
658 }
659
mutable_partition_graph(size_t i)660 GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
661 return &partition_graphs_[i];
662 }
663
AddPartitionGraph(const GraphDef & partition_graph)664 void InMemoryRunGraphResponse::AddPartitionGraph(
665 const GraphDef& partition_graph) {
666 partition_graphs_.push_back(partition_graph);
667 }
668
num_recvs() const669 size_t OwnedProtoRunGraphResponse::num_recvs() const {
670 return response_.recv_size();
671 }
672
recv_key(size_t i) const673 const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
674 return response_.recv(i).name();
675 }
676
RecvValue(size_t i,TensorProto * out_tensor)677 Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
678 TensorProto* out_tensor) {
679 out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
680 return Status::OK();
681 }
682
RecvValue(size_t i,Tensor * out_tensor)683 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
684 if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
685 return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
686 } else {
687 return Status::OK();
688 }
689 }
690
AddRecv(const string & key,const Tensor & value)691 void OwnedProtoRunGraphResponse::AddRecv(const string& key,
692 const Tensor& value) {
693 NamedTensorProto* recv = response_.add_recv();
694 recv->set_name(key);
695 TensorProto* value_proto = recv->mutable_tensor();
696 value.AsProtoTensorContent(value_proto);
697 }
698
mutable_step_stats()699 StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
700 return response_.mutable_step_stats();
701 }
702
mutable_cost_graph()703 CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
704 return response_.mutable_cost_graph();
705 }
706
status_code() const707 errors::Code OwnedProtoRunGraphResponse::status_code() const {
708 return response_.status_code();
709 }
710
status_error_message() const711 const string& OwnedProtoRunGraphResponse::status_error_message() const {
712 return response_.status_error_message();
713 }
714
set_status(const Status & status)715 void OwnedProtoRunGraphResponse::set_status(const Status& status) {
716 response_.set_status_code(status.code());
717 response_.set_status_error_message(status.error_message());
718 }
719
get_proto()720 RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
721
num_partition_graphs() const722 size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
723 return response_.partition_graph_size();
724 }
725
mutable_partition_graph(size_t i)726 GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
727 return response_.mutable_partition_graph(i);
728 }
729
AddPartitionGraph(const GraphDef & partition_graph)730 void OwnedProtoRunGraphResponse::AddPartitionGraph(
731 const GraphDef& partition_graph) {
732 GraphDef* graph_def = response_.mutable_partition_graph()->Add();
733 *graph_def = partition_graph;
734 }
735
NonOwnedProtoRunGraphResponse(RunGraphResponse * response)736 NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
737 RunGraphResponse* response)
738 : response_(response) {}
739
num_recvs() const740 size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
741 return response_->recv_size();
742 }
743
recv_key(size_t i) const744 const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
745 return response_->recv(i).name();
746 }
747
RecvValue(size_t i,TensorProto * out_tensor)748 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
749 TensorProto* out_tensor) {
750 out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
751 return Status::OK();
752 }
753
RecvValue(size_t i,Tensor * out_tensor)754 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
755 if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
756 return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
757 } else {
758 return Status::OK();
759 }
760 }
761
AddRecv(const string & key,const Tensor & value)762 void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
763 const Tensor& value) {
764 NamedTensorProto* recv = response_->add_recv();
765 recv->set_name(key);
766 TensorProto* value_proto = recv->mutable_tensor();
767 value.AsProtoTensorContent(value_proto);
768 }
769
mutable_step_stats()770 StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
771 return response_->mutable_step_stats();
772 }
773
mutable_cost_graph()774 CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
775 return response_->mutable_cost_graph();
776 }
777
status_code() const778 errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
779 return response_->status_code();
780 }
781
status_error_message() const782 const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
783 return response_->status_error_message();
784 }
785
set_status(const Status & status)786 void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
787 response_->set_status_code(status.code());
788 response_->set_status_error_message(status.error_message());
789 }
790
get_proto()791 RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
792 return response_;
793 }
794
num_partition_graphs() const795 size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
796 return response_->partition_graph_size();
797 }
798
mutable_partition_graph(size_t i)799 GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
800 return response_->mutable_partition_graph(i);
801 }
802
AddPartitionGraph(const GraphDef & partition_graph)803 void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
804 const GraphDef& partition_graph) {
805 GraphDef* graph_def = response_->add_partition_graph();
806 *graph_def = partition_graph;
807 }
808
~MutableRunStepResponseWrapper()809 MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
810
num_tensors() const811 size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
812
tensor_name(size_t i) const813 const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
814 return tensors_[i].first;
815 }
816
TensorValue(size_t i,Tensor * out_tensor) const817 Status InMemoryRunStepResponse::TensorValue(size_t i,
818 Tensor* out_tensor) const {
819 *out_tensor = tensors_[i].second;
820 return Status::OK();
821 }
822
metadata() const823 const RunMetadata& InMemoryRunStepResponse::metadata() const {
824 return metadata_;
825 }
826
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * wrapper,size_t i)827 Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
828 const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
829 Tensor tensor;
830 TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
831 tensors_.emplace_back(name, tensor);
832 return Status::OK();
833 }
834
mutable_metadata()835 RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
836
status_code() const837 errors::Code InMemoryRunStepResponse::status_code() const {
838 return status_.code();
839 }
840
status_error_message() const841 const string& InMemoryRunStepResponse::status_error_message() const {
842 return status_.error_message();
843 }
844
set_status(const Status & status)845 void InMemoryRunStepResponse::set_status(const Status& status) {
846 status_ = status;
847 }
848
get_proto()849 RunStepResponse* InMemoryRunStepResponse::get_proto() {
850 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
851 return nullptr;
852 }
853
num_tensors() const854 size_t OwnedProtoRunStepResponse::num_tensors() const {
855 return response_.tensor_size();
856 }
857
tensor_name(size_t i) const858 const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
859 return response_.tensor(i).name();
860 }
861
TensorValue(size_t i,Tensor * out_tensor) const862 Status OwnedProtoRunStepResponse::TensorValue(size_t i,
863 Tensor* out_tensor) const {
864 if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
865 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
866 } else {
867 return Status::OK();
868 }
869 }
870
metadata() const871 const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
872 return response_.metadata();
873 }
874
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)875 Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
876 const string& name, MutableRunGraphResponseWrapper* run_graph_response,
877 size_t i) {
878 NamedTensorProto* response_tensor = response_.add_tensor();
879 response_tensor->set_name(name);
880 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
881 }
882
mutable_metadata()883 RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
884 return response_.mutable_metadata();
885 }
886
status_code() const887 errors::Code OwnedProtoRunStepResponse::status_code() const {
888 return response_.status_code();
889 }
890
status_error_message() const891 const string& OwnedProtoRunStepResponse::status_error_message() const {
892 return response_.status_error_message();
893 }
894
set_status(const Status & status)895 void OwnedProtoRunStepResponse::set_status(const Status& status) {
896 response_.set_status_code(status.code());
897 response_.set_status_error_message(status.error_message());
898 }
899
get_proto()900 RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
901
NonOwnedProtoRunStepResponse(RunStepResponse * response)902 NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
903 RunStepResponse* response)
904 : response_(response) {}
905
num_tensors() const906 size_t NonOwnedProtoRunStepResponse::num_tensors() const {
907 return response_->tensor_size();
908 }
909
tensor_name(size_t i) const910 const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
911 return response_->tensor(i).name();
912 }
913
TensorValue(size_t i,Tensor * out_tensor) const914 Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
915 Tensor* out_tensor) const {
916 if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
917 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
918 } else {
919 return Status::OK();
920 }
921 }
922
metadata() const923 const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
924 return response_->metadata();
925 }
926
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)927 Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
928 const string& name, MutableRunGraphResponseWrapper* run_graph_response,
929 size_t i) {
930 NamedTensorProto* response_tensor = response_->add_tensor();
931 response_tensor->set_name(name);
932 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
933 }
934
mutable_metadata()935 RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
936 return response_->mutable_metadata();
937 }
938
status_code() const939 errors::Code NonOwnedProtoRunStepResponse::status_code() const {
940 return response_->status_code();
941 }
942
status_error_message() const943 const string& NonOwnedProtoRunStepResponse::status_error_message() const {
944 return response_->status_error_message();
945 }
946
set_status(const Status & status)947 void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
948 response_->set_status_code(status.code());
949 response_->set_status_error_message(status.error_message());
950 }
951
get_proto()952 RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
953
954 } // namespace tensorflow
955