1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/distributed_runtime/message_wrappers.h"

#include <utility>

#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
23
24 namespace tensorflow {
25
ParseTensorProtoToTensor(const TensorProto & tensor_proto,Tensor * out_tensor)26 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
27 Tensor* out_tensor) {
28 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
29 Tensor parsed(tensor_proto.dtype());
30 if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
31 *out_tensor = parsed;
32 return true;
33 }
34 }
35 return false;
36 }
37
session_handle() const38 const string& InMemoryRunStepRequest::session_handle() const {
39 return session_handle_;
40 }
41
set_session_handle(const string & handle)42 void InMemoryRunStepRequest::set_session_handle(const string& handle) {
43 session_handle_ = handle;
44 }
45
partial_run_handle() const46 const string& InMemoryRunStepRequest::partial_run_handle() const {
47 return partial_run_handle_;
48 }
49
set_partial_run_handle(const string & handle)50 void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
51 partial_run_handle_ = handle;
52 }
53
num_feeds() const54 size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
feed_name(size_t i) const55 const string& InMemoryRunStepRequest::feed_name(size_t i) const {
56 return feeds_[i].first;
57 }
58
FeedValue(size_t i,Tensor * out_tensor) const59 Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
60 *out_tensor = feeds_[i].second;
61 return Status::OK();
62 }
63
FeedValue(size_t i,TensorProto * out_tensor) const64 Status InMemoryRunStepRequest::FeedValue(size_t i,
65 TensorProto* out_tensor) const {
66 feeds_[i].second.AsProtoTensorContent(out_tensor);
67 return Status::OK();
68 }
69
add_feed(const string & name,const Tensor & value)70 void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
71 feeds_.emplace_back(name, value);
72 }
73
num_fetches() const74 size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
fetch_name(size_t i) const75 const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
76 return fetches_[i];
77 }
add_fetch(const string & name)78 void InMemoryRunStepRequest::add_fetch(const string& name) {
79 fetches_.push_back(name);
80 }
81
num_targets() const82 size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
target_name(size_t i) const83 const string& InMemoryRunStepRequest::target_name(size_t i) const {
84 return targets_[i];
85 }
add_target(const string & name)86 void InMemoryRunStepRequest::add_target(const string& name) {
87 targets_.push_back(name);
88 }
89
options() const90 const RunOptions& InMemoryRunStepRequest::options() const { return options_; }
91
mutable_options()92 RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }
93
store_errors_in_response_body() const94 bool InMemoryRunStepRequest::store_errors_in_response_body() const {
95 return store_errors_in_response_body_;
96 }
97
request_id() const98 int64 InMemoryRunStepRequest::request_id() const {
99 return 0; // no need to track request id for local version.
100 }
101
set_store_errors_in_response_body(bool store_errors)102 void InMemoryRunStepRequest::set_store_errors_in_response_body(
103 bool store_errors) {
104 store_errors_in_response_body_ = store_errors;
105 }
106
DebugString() const107 string InMemoryRunStepRequest::DebugString() const {
108 return ToProto().DebugString();
109 }
110
ToProto() const111 const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
112 if (!proto_version_) {
113 proto_version_.reset(new RunStepRequest);
114 proto_version_->set_session_handle(session_handle());
115 proto_version_->set_partial_run_handle(partial_run_handle());
116 for (size_t i = 0; i < num_feeds(); ++i) {
117 auto feed = proto_version_->add_feed();
118 feed->set_name(feed_name(i));
119 feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
120 }
121 for (size_t i = 0; i < num_fetches(); ++i) {
122 proto_version_->add_fetch(fetch_name(i));
123 }
124 for (size_t i = 0; i < num_targets(); ++i) {
125 proto_version_->add_target(target_name(i));
126 }
127 *proto_version_->mutable_options() = options();
128 }
129 return *proto_version_;
130 }
131
session_handle() const132 const string& MutableProtoRunStepRequest::session_handle() const {
133 return request_.session_handle();
134 }
set_session_handle(const string & handle)135 void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
136 request_.set_session_handle(handle);
137 }
138
partial_run_handle() const139 const string& MutableProtoRunStepRequest::partial_run_handle() const {
140 return request_.partial_run_handle();
141 }
set_partial_run_handle(const string & handle)142 void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
143 request_.set_partial_run_handle(handle);
144 }
145
num_feeds() const146 size_t MutableProtoRunStepRequest::num_feeds() const {
147 return request_.feed_size();
148 }
feed_name(size_t i) const149 const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
150 return request_.feed(i).name();
151 }
FeedValue(size_t i,Tensor * out_tensor) const152 Status MutableProtoRunStepRequest::FeedValue(size_t i,
153 Tensor* out_tensor) const {
154 if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
155 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
156 } else {
157 return Status::OK();
158 }
159 }
160
FeedValue(size_t i,TensorProto * out_tensor) const161 Status MutableProtoRunStepRequest::FeedValue(size_t i,
162 TensorProto* out_tensor) const {
163 *out_tensor = request_.feed(i).tensor();
164 return Status::OK();
165 }
166
add_feed(const string & name,const Tensor & value)167 void MutableProtoRunStepRequest::add_feed(const string& name,
168 const Tensor& value) {
169 NamedTensorProto* feed = request_.add_feed();
170 feed->set_name(name);
171 TensorProto* value_proto = feed->mutable_tensor();
172 value.AsProtoTensorContent(value_proto);
173 }
174
num_fetches() const175 size_t MutableProtoRunStepRequest::num_fetches() const {
176 return request_.fetch_size();
177 }
178
fetch_name(size_t i) const179 const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
180 return request_.fetch(i);
181 }
add_fetch(const string & name)182 void MutableProtoRunStepRequest::add_fetch(const string& name) {
183 request_.add_fetch(name);
184 }
185
num_targets() const186 size_t MutableProtoRunStepRequest::num_targets() const {
187 return request_.target_size();
188 }
189
target_name(size_t i) const190 const string& MutableProtoRunStepRequest::target_name(size_t i) const {
191 return request_.target(i);
192 }
193
add_target(const string & name)194 void MutableProtoRunStepRequest::add_target(const string& name) {
195 request_.add_target(name);
196 }
197
options() const198 const RunOptions& MutableProtoRunStepRequest::options() const {
199 return request_.options();
200 }
201
mutable_options()202 RunOptions* MutableProtoRunStepRequest::mutable_options() {
203 return request_.mutable_options();
204 }
205
store_errors_in_response_body() const206 bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
207 return request_.store_errors_in_response_body();
208 }
209
set_store_errors_in_response_body(bool store_errors)210 void MutableProtoRunStepRequest::set_store_errors_in_response_body(
211 bool store_errors) {
212 request_.set_store_errors_in_response_body(store_errors);
213 }
214
request_id() const215 int64 MutableProtoRunStepRequest::request_id() const {
216 return request_.request_id();
217 }
218
DebugString() const219 string MutableProtoRunStepRequest::DebugString() const {
220 return request_.DebugString();
221 }
222
ToProto() const223 const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
224 return request_;
225 }
226
ProtoRunStepRequest(const RunStepRequest * request)227 ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
228 : request_(request) {}
229
session_handle() const230 const string& ProtoRunStepRequest::session_handle() const {
231 return request_->session_handle();
232 }
233
partial_run_handle() const234 const string& ProtoRunStepRequest::partial_run_handle() const {
235 return request_->partial_run_handle();
236 }
237
num_feeds() const238 size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }
239
feed_name(size_t i) const240 const string& ProtoRunStepRequest::feed_name(size_t i) const {
241 return request_->feed(i).name();
242 }
243
FeedValue(size_t i,Tensor * out_tensor) const244 Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
245 if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
246 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
247 } else {
248 return Status::OK();
249 }
250 }
251
FeedValue(size_t i,TensorProto * out_tensor) const252 Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
253 *out_tensor = request_->feed(i).tensor();
254 return Status::OK();
255 }
256
num_fetches() const257 size_t ProtoRunStepRequest::num_fetches() const {
258 return request_->fetch_size();
259 }
260
fetch_name(size_t i) const261 const string& ProtoRunStepRequest::fetch_name(size_t i) const {
262 return request_->fetch(i);
263 }
264
num_targets() const265 size_t ProtoRunStepRequest::num_targets() const {
266 return request_->target_size();
267 }
268
target_name(size_t i) const269 const string& ProtoRunStepRequest::target_name(size_t i) const {
270 return request_->target(i);
271 }
272
options() const273 const RunOptions& ProtoRunStepRequest::options() const {
274 return request_->options();
275 }
276
store_errors_in_response_body() const277 bool ProtoRunStepRequest::store_errors_in_response_body() const {
278 return request_->store_errors_in_response_body();
279 }
280
request_id() const281 int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }
282
DebugString() const283 string ProtoRunStepRequest::DebugString() const {
284 return request_->DebugString();
285 }
286
ToProto() const287 const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
288
session_handle() const289 const string& InMemoryRunGraphRequest::session_handle() const {
290 return session_handle_;
291 }
292
create_worker_session_called() const293 bool InMemoryRunGraphRequest::create_worker_session_called() const {
294 return create_worker_session_called_;
295 }
296
set_session_handle(const string & handle)297 void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
298 session_handle_ = handle;
299 }
300
set_create_worker_session_called(bool called)301 void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
302 create_worker_session_called_ = called;
303 }
304
graph_handle() const305 const string& InMemoryRunGraphRequest::graph_handle() const {
306 return graph_handle_;
307 }
308
set_graph_handle(const string & handle)309 void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
310 graph_handle_ = handle;
311 }
312
step_id() const313 int64 InMemoryRunGraphRequest::step_id() const { return step_id_; }
314
set_step_id(int64 step_id)315 void InMemoryRunGraphRequest::set_step_id(int64 step_id) { step_id_ = step_id; }
316
exec_opts() const317 const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
318 return exec_opts_;
319 }
320
mutable_exec_opts()321 ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
322 return &exec_opts_;
323 }
324
num_sends() const325 size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }
326
send_key(size_t i) const327 const string& InMemoryRunGraphRequest::send_key(size_t i) const {
328 return sends_[i].first;
329 }
330
SendValue(size_t i,Tensor * out_tensor) const331 Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
332 *out_tensor = sends_[i].second;
333 return Status::OK();
334 }
335
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)336 Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
337 const RunStepRequestWrapper& run_step_request, size_t i,
338 const string& send_key) {
339 Tensor tensor;
340 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
341 sends_.emplace_back(send_key, std::move(tensor));
342 return Status::OK();
343 }
344
345 // TODO(b/74355905): Add a specialized implementation that avoids
346 // copying the tensor when at least two of the {client, master,
347 // worker} are in the same process.
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)348 Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
349 const RunCallableRequest& run_callable_request, size_t i,
350 const string& send_key) {
351 Tensor tensor;
352 if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
353 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
354 }
355 sends_.emplace_back(send_key, std::move(tensor));
356 return Status::OK();
357 }
358
num_recvs() const359 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
360
recv_key(size_t i) const361 const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
362 return recvs_[i];
363 }
364
add_recv_key(const string & recv_key)365 void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
366 recvs_.push_back(recv_key);
367 }
368
is_partial() const369 bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }
370
set_is_partial(bool is_partial)371 void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
372 is_partial_ = is_partial;
373 }
374
is_last_partial_run() const375 bool InMemoryRunGraphRequest::is_last_partial_run() const {
376 return is_last_partial_run_;
377 }
378
set_is_last_partial_run(bool is_last_partial_run)379 void InMemoryRunGraphRequest::set_is_last_partial_run(
380 bool is_last_partial_run) {
381 is_last_partial_run_ = is_last_partial_run;
382 }
383
store_errors_in_response_body() const384 bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
385 return store_errors_in_response_body_;
386 }
387
set_store_errors_in_response_body(bool store_errors)388 void InMemoryRunGraphRequest::set_store_errors_in_response_body(
389 bool store_errors) {
390 store_errors_in_response_body_ = store_errors;
391 }
392
request_id() const393 int64 InMemoryRunGraphRequest::request_id() const { return request_id_; }
394
set_request_id(int64 request_id)395 void InMemoryRunGraphRequest::set_request_id(int64 request_id) {
396 request_id_ = request_id;
397 }
398
ToProto() const399 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
400 if (!proto_version_) {
401 proto_version_.reset(new RunGraphRequest);
402 proto_version_->set_session_handle(session_handle());
403 proto_version_->set_create_worker_session_called(
404 create_worker_session_called());
405 proto_version_->set_graph_handle(graph_handle());
406 proto_version_->set_step_id(step_id());
407 *proto_version_->mutable_exec_opts() = exec_opts();
408 for (size_t i = 0; i < num_sends(); ++i) {
409 auto send = proto_version_->add_send();
410 send->set_name(send_key(i));
411 sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
412 }
413 for (size_t i = 0; i < num_recvs(); ++i) {
414 proto_version_->add_recv_key(recv_key(i));
415 }
416 proto_version_->set_is_partial(is_partial());
417 proto_version_->set_is_last_partial_run(is_last_partial_run());
418 }
419 proto_version_->set_store_errors_in_response_body(
420 store_errors_in_response_body_);
421 proto_version_->set_request_id(request_id_);
422 return *proto_version_;
423 }
424
session_handle() const425 const string& MutableProtoRunGraphRequest::session_handle() const {
426 return request_.session_handle();
427 }
428
set_session_handle(const string & handle)429 void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
430 request_.set_session_handle(handle);
431 }
432
create_worker_session_called() const433 bool MutableProtoRunGraphRequest::create_worker_session_called() const {
434 return request_.create_worker_session_called();
435 }
436
set_create_worker_session_called(bool called)437 void MutableProtoRunGraphRequest::set_create_worker_session_called(
438 bool called) {
439 request_.set_create_worker_session_called(called);
440 }
441
graph_handle() const442 const string& MutableProtoRunGraphRequest::graph_handle() const {
443 return request_.graph_handle();
444 }
445
set_graph_handle(const string & handle)446 void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
447 request_.set_graph_handle(handle);
448 }
449
step_id() const450 int64 MutableProtoRunGraphRequest::step_id() const {
451 return request_.step_id();
452 }
453
set_step_id(int64 step_id)454 void MutableProtoRunGraphRequest::set_step_id(int64 step_id) {
455 request_.set_step_id(step_id);
456 }
457
exec_opts() const458 const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
459 return request_.exec_opts();
460 }
461
mutable_exec_opts()462 ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
463 return request_.mutable_exec_opts();
464 }
465
num_sends() const466 size_t MutableProtoRunGraphRequest::num_sends() const {
467 return request_.send_size();
468 }
469
send_key(size_t i) const470 const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
471 return request_.send(i).name();
472 }
473
SendValue(size_t i,Tensor * out_tensor) const474 Status MutableProtoRunGraphRequest::SendValue(size_t i,
475 Tensor* out_tensor) const {
476 if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
477 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
478 } else {
479 return Status::OK();
480 }
481 }
482
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)483 Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
484 const RunStepRequestWrapper& run_step_request, size_t i,
485 const string& send_key) {
486 NamedTensorProto* send = request_.add_send();
487 send->set_name(send_key);
488 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
489 return Status::OK();
490 }
491
492 // TODO(b/74355905): Add a specialized implementation that avoids
493 // copying the tensor when at least two of the {client, master,
494 // worker} are in the same process.
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)495 Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
496 const RunCallableRequest& run_callable_request, size_t i,
497 const string& send_key) {
498 NamedTensorProto* send = request_.add_send();
499 send->set_name(send_key);
500 *send->mutable_tensor() = run_callable_request.feed(i);
501 return Status::OK();
502 }
503
num_recvs() const504 size_t MutableProtoRunGraphRequest::num_recvs() const {
505 return request_.recv_key_size();
506 }
507
recv_key(size_t i) const508 const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
509 return request_.recv_key(i);
510 }
511
add_recv_key(const string & recv_key)512 void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
513 request_.add_recv_key(recv_key);
514 }
515
is_partial() const516 bool MutableProtoRunGraphRequest::is_partial() const {
517 return request_.is_partial();
518 }
519
set_is_partial(bool is_partial)520 void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
521 request_.set_is_partial(is_partial);
522 }
523
is_last_partial_run() const524 bool MutableProtoRunGraphRequest::is_last_partial_run() const {
525 return request_.is_last_partial_run();
526 }
527
set_is_last_partial_run(bool is_last_partial_run)528 void MutableProtoRunGraphRequest::set_is_last_partial_run(
529 bool is_last_partial_run) {
530 request_.set_is_last_partial_run(is_last_partial_run);
531 }
532
store_errors_in_response_body() const533 bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
534 return request_.store_errors_in_response_body();
535 }
536
set_store_errors_in_response_body(bool store_errors)537 void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
538 bool store_errors) {
539 request_.set_store_errors_in_response_body(store_errors);
540 }
541
request_id() const542 int64 MutableProtoRunGraphRequest::request_id() const {
543 return request_.request_id();
544 }
545
set_request_id(int64 request_id)546 void MutableProtoRunGraphRequest::set_request_id(int64 request_id) {
547 request_.set_request_id(request_id);
548 }
549
ToProto() const550 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
551 return request_;
552 }
553
ProtoRunGraphRequest(const RunGraphRequest * request)554 ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
555 : request_(request) {}
556
session_handle() const557 const string& ProtoRunGraphRequest::session_handle() const {
558 return request_->session_handle();
559 }
560
create_worker_session_called() const561 bool ProtoRunGraphRequest::create_worker_session_called() const {
562 return request_->create_worker_session_called();
563 }
564
graph_handle() const565 const string& ProtoRunGraphRequest::graph_handle() const {
566 return request_->graph_handle();
567 }
568
step_id() const569 int64 ProtoRunGraphRequest::step_id() const { return request_->step_id(); }
570
exec_opts() const571 const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
572 return request_->exec_opts();
573 }
574
num_sends() const575 size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }
576
send_key(size_t i) const577 const string& ProtoRunGraphRequest::send_key(size_t i) const {
578 return request_->send(i).name();
579 }
580
SendValue(size_t i,Tensor * out_tensor) const581 Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
582 if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
583 return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
584 } else {
585 return Status::OK();
586 }
587 }
588
num_recvs() const589 size_t ProtoRunGraphRequest::num_recvs() const {
590 return request_->recv_key_size();
591 }
592
recv_key(size_t i) const593 const string& ProtoRunGraphRequest::recv_key(size_t i) const {
594 return request_->recv_key(i);
595 }
596
is_partial() const597 bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }
598
is_last_partial_run() const599 bool ProtoRunGraphRequest::is_last_partial_run() const {
600 return request_->is_last_partial_run();
601 }
602
store_errors_in_response_body() const603 bool ProtoRunGraphRequest::store_errors_in_response_body() const {
604 return request_->store_errors_in_response_body();
605 }
606
request_id() const607 int64 ProtoRunGraphRequest::request_id() const {
608 return request_->request_id();
609 }
610
ToProto() const611 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
612 return *request_;
613 }
614
num_recvs() const615 size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }
616
recv_key(size_t i) const617 const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
618 return recvs_[i].first;
619 }
620
RecvValue(size_t i,TensorProto * out_tensor)621 Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
622 recvs_[i].second.AsProtoTensorContent(out_tensor);
623 return Status::OK();
624 }
625
RecvValue(size_t i,Tensor * out_tensor)626 Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
627 *out_tensor = recvs_[i].second;
628 return Status::OK();
629 }
630
AddRecv(const string & key,const Tensor & value)631 void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
632 recvs_.emplace_back(key, value);
633 }
634
mutable_step_stats()635 StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
636 return &step_stats_;
637 }
638
mutable_cost_graph()639 CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
640 return &cost_graph_;
641 }
642
status_code() const643 errors::Code InMemoryRunGraphResponse::status_code() const {
644 return status_.code();
645 }
646
status_error_message() const647 const string& InMemoryRunGraphResponse::status_error_message() const {
648 return status_.error_message();
649 }
650
set_status(const Status & status)651 void InMemoryRunGraphResponse::set_status(const Status& status) {
652 status_ = status;
653 }
654
get_proto()655 RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
656 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
657 return nullptr;
658 }
659
num_partition_graphs() const660 size_t InMemoryRunGraphResponse::num_partition_graphs() const {
661 return partition_graphs_.size();
662 }
663
mutable_partition_graph(size_t i)664 GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
665 return &partition_graphs_[i];
666 }
667
AddPartitionGraph(const GraphDef & partition_graph)668 void InMemoryRunGraphResponse::AddPartitionGraph(
669 const GraphDef& partition_graph) {
670 partition_graphs_.push_back(partition_graph);
671 }
672
num_recvs() const673 size_t OwnedProtoRunGraphResponse::num_recvs() const {
674 return response_.recv_size();
675 }
676
recv_key(size_t i) const677 const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
678 return response_.recv(i).name();
679 }
680
RecvValue(size_t i,TensorProto * out_tensor)681 Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
682 TensorProto* out_tensor) {
683 out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
684 return Status::OK();
685 }
686
RecvValue(size_t i,Tensor * out_tensor)687 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
688 if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
689 return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
690 } else {
691 return Status::OK();
692 }
693 }
694
AddRecv(const string & key,const Tensor & value)695 void OwnedProtoRunGraphResponse::AddRecv(const string& key,
696 const Tensor& value) {
697 NamedTensorProto* recv = response_.add_recv();
698 recv->set_name(key);
699 TensorProto* value_proto = recv->mutable_tensor();
700 value.AsProtoTensorContent(value_proto);
701 }
702
mutable_step_stats()703 StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
704 return response_.mutable_step_stats();
705 }
706
mutable_cost_graph()707 CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
708 return response_.mutable_cost_graph();
709 }
710
status_code() const711 errors::Code OwnedProtoRunGraphResponse::status_code() const {
712 return response_.status_code();
713 }
714
status_error_message() const715 const string& OwnedProtoRunGraphResponse::status_error_message() const {
716 return response_.status_error_message();
717 }
718
set_status(const Status & status)719 void OwnedProtoRunGraphResponse::set_status(const Status& status) {
720 response_.set_status_code(status.code());
721 response_.set_status_error_message(status.error_message());
722 }
723
get_proto()724 RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
725
num_partition_graphs() const726 size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
727 return response_.partition_graph_size();
728 }
729
mutable_partition_graph(size_t i)730 GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
731 return response_.mutable_partition_graph(i);
732 }
733
AddPartitionGraph(const GraphDef & partition_graph)734 void OwnedProtoRunGraphResponse::AddPartitionGraph(
735 const GraphDef& partition_graph) {
736 GraphDef* graph_def = response_.mutable_partition_graph()->Add();
737 *graph_def = partition_graph;
738 }
739
NonOwnedProtoRunGraphResponse(RunGraphResponse * response)740 NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
741 RunGraphResponse* response)
742 : response_(response) {}
743
num_recvs() const744 size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
745 return response_->recv_size();
746 }
747
recv_key(size_t i) const748 const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
749 return response_->recv(i).name();
750 }
751
RecvValue(size_t i,TensorProto * out_tensor)752 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
753 TensorProto* out_tensor) {
754 out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
755 return Status::OK();
756 }
757
RecvValue(size_t i,Tensor * out_tensor)758 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
759 if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
760 return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
761 } else {
762 return Status::OK();
763 }
764 }
765
AddRecv(const string & key,const Tensor & value)766 void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
767 const Tensor& value) {
768 NamedTensorProto* recv = response_->add_recv();
769 recv->set_name(key);
770 TensorProto* value_proto = recv->mutable_tensor();
771 value.AsProtoTensorContent(value_proto);
772 }
773
mutable_step_stats()774 StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
775 return response_->mutable_step_stats();
776 }
777
mutable_cost_graph()778 CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
779 return response_->mutable_cost_graph();
780 }
781
status_code() const782 errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
783 return response_->status_code();
784 }
785
status_error_message() const786 const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
787 return response_->status_error_message();
788 }
789
set_status(const Status & status)790 void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
791 response_->set_status_code(status.code());
792 response_->set_status_error_message(status.error_message());
793 }
794
get_proto()795 RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
796 return response_;
797 }
798
num_partition_graphs() const799 size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
800 return response_->partition_graph_size();
801 }
802
mutable_partition_graph(size_t i)803 GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
804 return response_->mutable_partition_graph(i);
805 }
806
AddPartitionGraph(const GraphDef & partition_graph)807 void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
808 const GraphDef& partition_graph) {
809 GraphDef* graph_def = response_->add_partition_graph();
810 *graph_def = partition_graph;
811 }
812
~MutableRunStepResponseWrapper()813 MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
814
num_tensors() const815 size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
816
tensor_name(size_t i) const817 const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
818 return tensors_[i].first;
819 }
820
TensorValue(size_t i,Tensor * out_tensor) const821 Status InMemoryRunStepResponse::TensorValue(size_t i,
822 Tensor* out_tensor) const {
823 *out_tensor = tensors_[i].second;
824 return Status::OK();
825 }
826
metadata() const827 const RunMetadata& InMemoryRunStepResponse::metadata() const {
828 return metadata_;
829 }
830
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * wrapper,size_t i)831 Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
832 const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
833 Tensor tensor;
834 TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
835 tensors_.emplace_back(name, tensor);
836 return Status::OK();
837 }
838
mutable_metadata()839 RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
840
status_code() const841 errors::Code InMemoryRunStepResponse::status_code() const {
842 return status_.code();
843 }
844
status_error_message() const845 const string& InMemoryRunStepResponse::status_error_message() const {
846 return status_.error_message();
847 }
848
set_status(const Status & status)849 void InMemoryRunStepResponse::set_status(const Status& status) {
850 status_ = status;
851 }
852
get_proto()853 RunStepResponse* InMemoryRunStepResponse::get_proto() {
854 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
855 return nullptr;
856 }
857
num_tensors() const858 size_t OwnedProtoRunStepResponse::num_tensors() const {
859 return response_.tensor_size();
860 }
861
tensor_name(size_t i) const862 const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
863 return response_.tensor(i).name();
864 }
865
TensorValue(size_t i,Tensor * out_tensor) const866 Status OwnedProtoRunStepResponse::TensorValue(size_t i,
867 Tensor* out_tensor) const {
868 if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
869 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
870 } else {
871 return Status::OK();
872 }
873 }
874
metadata() const875 const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
876 return response_.metadata();
877 }
878
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)879 Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
880 const string& name, MutableRunGraphResponseWrapper* run_graph_response,
881 size_t i) {
882 NamedTensorProto* response_tensor = response_.add_tensor();
883 response_tensor->set_name(name);
884 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
885 }
886
mutable_metadata()887 RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
888 return response_.mutable_metadata();
889 }
890
status_code() const891 errors::Code OwnedProtoRunStepResponse::status_code() const {
892 return response_.status_code();
893 }
894
status_error_message() const895 const string& OwnedProtoRunStepResponse::status_error_message() const {
896 return response_.status_error_message();
897 }
898
set_status(const Status & status)899 void OwnedProtoRunStepResponse::set_status(const Status& status) {
900 response_.set_status_code(status.code());
901 response_.set_status_error_message(status.error_message());
902 }
903
get_proto()904 RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
905
NonOwnedProtoRunStepResponse(RunStepResponse * response)906 NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
907 RunStepResponse* response)
908 : response_(response) {}
909
num_tensors() const910 size_t NonOwnedProtoRunStepResponse::num_tensors() const {
911 return response_->tensor_size();
912 }
913
tensor_name(size_t i) const914 const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
915 return response_->tensor(i).name();
916 }
917
TensorValue(size_t i,Tensor * out_tensor) const918 Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
919 Tensor* out_tensor) const {
920 if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
921 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
922 } else {
923 return Status::OK();
924 }
925 }
926
metadata() const927 const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
928 return response_->metadata();
929 }
930
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)931 Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
932 const string& name, MutableRunGraphResponseWrapper* run_graph_response,
933 size_t i) {
934 NamedTensorProto* response_tensor = response_->add_tensor();
935 response_tensor->set_name(name);
936 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
937 }
938
mutable_metadata()939 RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
940 return response_->mutable_metadata();
941 }
942
status_code() const943 errors::Code NonOwnedProtoRunStepResponse::status_code() const {
944 return response_->status_code();
945 }
946
status_error_message() const947 const string& NonOwnedProtoRunStepResponse::status_error_message() const {
948 return response_->status_error_message();
949 }
950
set_status(const Status & status)951 void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
952 response_->set_status_code(status.code());
953 response_->set_status_error_message(status.error_message());
954 }
955
get_proto()956 RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
957
958 } // namespace tensorflow
959