/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `MasterService.RunStep` request message.
//
// The `RunStepRequest` message can contain potentially large tensor
// data as part of its `feed` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for an immutable RunStepRequest message.
//
// This interface is typically used by server-side components in the
// TensorFlow master.
class RunStepRequestWrapper {
 public:
  virtual ~RunStepRequestWrapper() {}

  // REQUIRED: session_handle must be returned by a CreateSession call
  // to the same master service.
  virtual const string& session_handle() const = 0;

  // Partial run handle (optional). If specified, this will be a partial run
  // execution, run up to the specified fetches.
  virtual const string& partial_run_handle() const = 0;

  // Tensors to be fed in the step. Each feed is a named tensor.
  virtual size_t num_feeds() const = 0;
  virtual const string& feed_name(size_t i) const = 0;

  // Stores the content of the feed value at index `i` in `out_tensor`.
  virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
  virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;

  // Fetches. A list of tensor names. The caller expects a tensor to
  // be returned for each fetch[i] (see RunStepResponse.tensor). The
  // order of specified fetches does not change the execution order.
  virtual size_t num_fetches() const = 0;
  virtual const string& fetch_name(size_t i) const = 0;

  // Target nodes. A list of node names. The named nodes will be run
  // but their outputs will not be fetched.
  virtual size_t num_targets() const = 0;
  virtual const string& target_name(size_t i) const = 0;

  // Options for the run call.
  virtual const RunOptions& options() const = 0;

  // If true then some errors, e.g., execution errors that have long
  // error messages, may return an OK RunStepResponse with the actual
  // error saved in the status_code/status_error_message fields of the
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  virtual bool store_errors_in_response_body() const = 0;

  // Unique identifier for this request. Every RunStepRequest must have a
  // unique request_id, and retried RunStepRequests must have the same
  // request_id. If request_id is zero, retry detection is disabled.
  virtual int64 request_id() const = 0;

  // Returns a human-readable representation of this message for debugging.
  virtual string DebugString() const = 0;

  // Returns the wrapped data as a protocol buffer message.
  virtual const RunStepRequest& ToProto() const = 0;
};

// Abstract interface for a mutable RunStepRequest message.
//
// See `RunStepRequestWrapper` above for a description of the fields.
class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
 public:
  virtual void set_session_handle(const string& handle) = 0;
  virtual void set_partial_run_handle(const string& handle) = 0;
  virtual void add_feed(const string& name, const Tensor& value) = 0;
  virtual void add_fetch(const string& name) = 0;
  virtual void add_target(const string& name) = 0;
  virtual RunOptions* mutable_options() = 0;
  virtual void set_store_errors_in_response_body(bool store_errors) = 0;
};
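// Example (illustrative sketch, not part of the API): client-side code that
// holds some `MutableRunStepRequestWrapper* req` -- e.g. one of the concrete
// subclasses declared below -- might populate it as follows. The session
// handle, tensor names, and node name are hypothetical.
//
//   req->set_session_handle(session_handle);  // returned by CreateSession
//   req->add_feed("x:0", x_tensor);           // named input tensor
//   req->add_fetch("y:0");                    // tensor to be returned
//   req->add_target("init_op");               // node to run but not fetch
//   req->mutable_options()->set_timeout_in_ms(60 * 1000);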
// Specialized (and mutable) wrapper for RunStep requests between a client and
// master in the same address space.
class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
 public:
  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_partial_run_handle(const string& handle) override;
  void add_feed(const string& name, const Tensor& value) override;
  void add_fetch(const string& name) override;
  void add_target(const string& name) override;
  RunOptions* mutable_options() override;
  void set_store_errors_in_response_body(bool store_errors) override;

 private:
  string session_handle_;
  string partial_run_handle_;
  gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_;
  gtl::InlinedVector<string, 4> fetches_;
  gtl::InlinedVector<string, 4> targets_;
  RunOptions options_;
  bool store_errors_in_response_body_ = false;

  // Holds a cached and owned copy of the proto representation of this
  // request, if needed, so that `ToProto()` can return a
  // `const RunStepRequest&`.
  // NOTE(mrry): Although calls to `ToProto()` on this class are
  // expected to be rare, retaining ownership of the returned message
  // makes it easier to return a reference from the proto-backed
  // representations.
  mutable std::unique_ptr<RunStepRequest> proto_version_;
};

// Wrapper for mutable RunStep requests that uses a protobuf message.
//
// This wrapper class should be used for RunStep requests between a
// client and master in different address spaces.
class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
 public:
  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_partial_run_handle(const string& handle) override;
  void add_feed(const string& name, const Tensor& value) override;
  void add_fetch(const string& name) override;
  void add_target(const string& name) override;
  RunOptions* mutable_options() override;
  void set_store_errors_in_response_body(bool store_errors) override;

 private:
  RunStepRequest request_;
  friend class MasterInterface;
};

// Wrapper for immutable RunStep requests that use a non-owned
// protobuf message.
//
// This interface is typically used by server-side components in the
// TensorFlow master, where the incoming message is a (possibly const)
// `RunStepRequest*`.
class ProtoRunStepRequest : public RunStepRequestWrapper {
 public:
  ProtoRunStepRequest(const RunStepRequest* request);

  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;

 private:
  const RunStepRequest* const request_;  // Not owned.
};
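// Example (illustrative sketch, not part of the API): server-side code in the
// master that receives a `const RunStepRequest*` over RPC can read it through
// the immutable wrapper without copying tensor data up front. Variable names
// are hypothetical.
//
//   ProtoRunStepRequest request(proto_request);  // `proto_request` not owned
//   for (size_t i = 0; i < request.num_feeds(); ++i) {
//     Tensor feed;
//     Status s = request.FeedValue(i, &feed);
//     if (!s.ok()) return s;
//     // ... bind `feed` to the input named request.feed_name(i) ...
//   }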
////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `WorkerService.RunGraph` request message.
//
// The `RunGraphRequest` message can contain potentially large tensor
// data as part of its `send` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for an immutable RunGraphRequest message.
//
// This interface is typically used by server-side components in the
// TensorFlow worker.
class RunGraphRequestWrapper {
 public:
  virtual ~RunGraphRequestWrapper() {}

  // The session handle used to register the graph. If empty, a single global
  // namespace is used.
  virtual const string& session_handle() const = 0;

  // Set to true if `CreateWorkerSession` was called for `session_handle`.
  virtual bool create_worker_session_called() const = 0;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  virtual const string& graph_handle() const = 0;

  // A unique ID to distinguish different runs of the same graph.
  //
  // The master generates a globally unique `step_id` to distinguish
  // different runs of the graph computation. Subgraphs communicate
  // (e.g., send/recv ops) with each other using `step_id` to
  // distinguish tensors generated by different runs.
  virtual int64 step_id() const = 0;

  // Options for this step.
  virtual const ExecutorOpts& exec_opts() const = 0;

  // Sends the tensors in "send" into the graph before the run.
  virtual size_t num_sends() const = 0;
  virtual const string& send_key(size_t i) const = 0;
  virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;

  // Fetches the keys into `RunGraphResponse.recv` after the run.
  virtual size_t num_recvs() const = 0;
  virtual const string& recv_key(size_t i) const = 0;

  // True if the RunGraphRequest is a partial run request.
  virtual bool is_partial() const = 0;

  // True if this is the last partial run request in a sequence of requests.
  virtual bool is_last_partial_run() const = 0;

  // If true then some errors, e.g., execution errors that have long
  // error messages, may return an OK RunGraphResponse with the actual
  // error saved in the status_code/status_error_message fields of the
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  virtual bool store_errors_in_response_body() const = 0;

  virtual int64 request_id() const = 0;

  // Returns the wrapped data as a protocol buffer message.
  virtual const RunGraphRequest& ToProto() const = 0;
};

// Abstract interface for a mutable RunGraphRequest message.
//
// See `RunGraphRequestWrapper` above for a description of the fields.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
 public:
  virtual void set_session_handle(const string& handle) = 0;
  virtual void set_create_worker_session_called(bool called) = 0;
  virtual void set_graph_handle(const string& handle) = 0;
  virtual void set_step_id(int64 step_id) = 0;
  virtual ExecutorOpts* mutable_exec_opts() = 0;

  // Stores the i^{th} feed value in `run_step_request` in this
  // request with the given `send_key`.
  virtual Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) = 0;
  virtual Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) = 0;

  virtual void add_recv_key(const string& recv_key) = 0;
  virtual void set_is_partial(bool is_partial) = 0;
  virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
  virtual void set_store_errors_in_response_body(bool store_errors) = 0;
  virtual void set_request_id(int64 request_id) = 0;
};
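// Example (illustrative sketch, not part of the API): when the master turns a
// client's RunStep call into per-worker RunGraph calls, it can forward the
// i^{th} feed of the RunStep request into a RunGraph request without an extra
// tensor copy. The handles, keys, and variable names are hypothetical.
//
//   InMemoryRunGraphRequest graph_request;
//   graph_request.set_session_handle(session_handle);
//   graph_request.set_graph_handle(graph_handle);  // from RegisterGraph
//   graph_request.set_step_id(step_id);
//   Status s = graph_request.AddSendFromRunStepRequest(
//       step_request, /*i=*/0, /*send_key=*/send_key);
//   if (!s.ok()) { /* handle the error */ }
//   graph_request.add_recv_key(recv_key);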
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
 public:
  // RunGraphRequestWrapper methods.
  const string& session_handle() const override;
  const string& graph_handle() const override;
  bool create_worker_session_called() const override;
  int64 step_id() const override;
  const ExecutorOpts& exec_opts() const override;
  size_t num_sends() const override;
  const string& send_key(size_t i) const override;
  Status SendValue(size_t i, Tensor* out_tensor) const override;
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  bool is_partial() const override;
  bool is_last_partial_run() const override;
  const RunGraphRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;

  // MutableRunGraphRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_create_worker_session_called(bool called) override;
  void set_graph_handle(const string& handle) override;
  void set_step_id(int64 step_id) override;
  ExecutorOpts* mutable_exec_opts() override;
  Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) override;
  Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) override;
  void add_recv_key(const string& recv_key) override;
  void set_is_partial(bool is_partial) override;
  void set_is_last_partial_run(bool is_last_partial_run) override;
  void set_store_errors_in_response_body(bool store_errors) override;
  void set_request_id(int64 request_id) override;

 private:
  string session_handle_;
  bool create_worker_session_called_ = false;
  string graph_handle_;
  int64 step_id_;
  ExecutorOpts exec_opts_;
  gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_;
  gtl::InlinedVector<string, 4> recvs_;
  bool is_partial_ = false;
  bool is_last_partial_run_ = false;
  bool store_errors_in_response_body_ = false;
  int64 request_id_ = 0;

  // Holds a cached and owned copy of the proto representation of this
  // request, if needed, so that `ToProto()` can return a
  // `const RunGraphRequest&`.
  // NOTE(mrry): Although calls to `ToProto()` on this class are
  // expected to be rare, retaining ownership of the returned message
  // makes it easier to return a reference from the proto-backed
  // representations.
  mutable std::unique_ptr<RunGraphRequest> proto_version_;
};

class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
 public:
  // RunGraphRequestWrapper methods.
  const string& session_handle() const override;
  bool create_worker_session_called() const override;
  const string& graph_handle() const override;
  int64 step_id() const override;
  const ExecutorOpts& exec_opts() const override;
  size_t num_sends() const override;
  const string& send_key(size_t i) const override;
  Status SendValue(size_t i, Tensor* out_tensor) const override;
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  bool is_partial() const override;
  bool is_last_partial_run() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;
  const RunGraphRequest& ToProto() const override;

  // MutableRunGraphRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_create_worker_session_called(bool called) override;
  void set_graph_handle(const string& handle) override;
  void set_step_id(int64 step_id) override;
  ExecutorOpts* mutable_exec_opts() override;
  Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) override;
  Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) override;
  void add_recv_key(const string& recv_key) override;
  void set_is_partial(bool is_partial) override;
  void set_is_last_partial_run(bool is_last_partial_run) override;
  void set_store_errors_in_response_body(bool store_errors) override;
  void set_request_id(int64 request_id) override;

 private:
  RunGraphRequest request_;
};

class ProtoRunGraphRequest : public RunGraphRequestWrapper {
 public:
  ProtoRunGraphRequest(const RunGraphRequest* request);

  // RunGraphRequestWrapper methods.
  const string& session_handle() const override;
  bool create_worker_session_called() const override;
  const string& graph_handle() const override;
  int64 step_id() const override;
  const ExecutorOpts& exec_opts() const override;
  size_t num_sends() const override;
  const string& send_key(size_t i) const override;
  Status SendValue(size_t i, Tensor* out_tensor) const override;
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  bool is_partial() const override;
  bool is_last_partial_run() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;
  const RunGraphRequest& ToProto() const override;

 private:
  const RunGraphRequest* const request_;  // Not owned.
};
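// Example (illustrative sketch, not part of the API): worker-side code that
// receives a `const RunGraphRequest*` over RPC iterates over the sends
// through the immutable wrapper; delivery into the rendezvous is elided.
// Variable names are hypothetical.
//
//   ProtoRunGraphRequest request(proto_request);  // `proto_request` not owned
//   for (size_t i = 0; i < request.num_sends(); ++i) {
//     Tensor val;
//     Status s = request.SendValue(i, &val);
//     if (!s.ok()) return s;
//     // ... deliver `val` under the key request.send_key(i) ...
//   }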
////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `WorkerService.RunGraph` response message.
//
// The `RunGraphResponse` message can contain potentially large tensor
// data as part of its `recv` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for a mutable RunGraphResponse message.
//
// Note that there is no corresponding (immutable)
// RunGraphResponseWrapper class, because the RunGraphResponse object
// is always used as a mutable pointer.
class MutableRunGraphResponseWrapper {
 public:
  virtual ~MutableRunGraphResponseWrapper() {}

  // A list of tensors corresponding to those requested by
  // `RunGraphRequest.recv_key`.
  virtual size_t num_recvs() const = 0;
  virtual const string& recv_key(size_t i) const = 0;
  // NOTE: The following methods may perform a destructive read, for
  // efficiency.
  virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
  virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
  virtual void AddRecv(const string& key, const Tensor& value) = 0;

  // Submessages that store performance statistics about the subgraph
  // execution, if necessary.
  virtual StepStats* mutable_step_stats() = 0;
  virtual CostGraphDef* mutable_cost_graph() = 0;
  virtual size_t num_partition_graphs() const = 0;
  virtual GraphDef* mutable_partition_graph(size_t i) = 0;
  virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;

  // Returned status if requested.
  virtual errors::Code status_code() const = 0;
  virtual const string& status_error_message() const = 0;
  virtual void set_status(const Status& status) = 0;

 protected:
  // Returns a mutable protobuf message that represents the contents of
  // this wrapper, for passing to an RPC subsystem that will populate
  // the message.
  //
  // NOTE: Only `WorkerInterface` subclasses may call this method. The
  // `InMemoryRunGraphResponse` subclass does not implement this
  // method, and attempts to call it will fail with a fatal
  // error. However, as long as callers always call
  // `WorkerInterface::RunGraphAsync()` with a wrapper object returned
  // from `WorkerInterface::CreateRunGraphResponse()` called on the
  // *same* WorkerInterface object, this error will never trigger.
  virtual RunGraphResponse* get_proto() = 0;
  friend class WorkerInterface;
};

class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
 public:
  // MutableRunGraphResponseWrapper methods.
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  Status RecvValue(size_t i, TensorProto* out_tensor) override;
  Status RecvValue(size_t i, Tensor* out_tensor) override;
  void AddRecv(const string& key, const Tensor& value) override;
  StepStats* mutable_step_stats() override;
  CostGraphDef* mutable_cost_graph() override;
  size_t num_partition_graphs() const override;
  GraphDef* mutable_partition_graph(size_t i) override;
  void AddPartitionGraph(const GraphDef& partition_graph) override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  // NOTE: This method is not implemented. See
  // MutableRunGraphResponseWrapper for an explanation.
  RunGraphResponse* get_proto() override;

 private:
  gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
  StepStats step_stats_;
  CostGraphDef cost_graph_;
  std::vector<GraphDef> partition_graphs_;
  // Store the code and message separately so that they can be updated
  // independently by setters.
  Status status_;
};

// Proto-based message wrapper for use on the client side of the RunGraph RPC.
class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
 public:
  // MutableRunGraphResponseWrapper methods.
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  Status RecvValue(size_t i, TensorProto* out_tensor) override;
  Status RecvValue(size_t i, Tensor* out_tensor) override;
  void AddRecv(const string& key, const Tensor& value) override;
  StepStats* mutable_step_stats() override;
  CostGraphDef* mutable_cost_graph() override;
  size_t num_partition_graphs() const override;
  GraphDef* mutable_partition_graph(size_t i) override;
  void AddPartitionGraph(const GraphDef& partition_graph) override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunGraphResponse* get_proto() override;

 private:
  RunGraphResponse response_;
};

// Proto-based message wrapper for use on the server side of the RunGraph RPC.
class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
 public:
  NonOwnedProtoRunGraphResponse(RunGraphResponse* response);

  // MutableRunGraphResponseWrapper methods.
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  Status RecvValue(size_t i, TensorProto* out_tensor) override;
  Status RecvValue(size_t i, Tensor* out_tensor) override;
  void AddRecv(const string& key, const Tensor& value) override;
  StepStats* mutable_step_stats() override;
  CostGraphDef* mutable_cost_graph() override;
  size_t num_partition_graphs() const override;
  GraphDef* mutable_partition_graph(size_t i) override;
  void AddPartitionGraph(const GraphDef& partition_graph) override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunGraphResponse* get_proto() override;

 private:
  RunGraphResponse* const response_;
};
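// Example (illustrative sketch, not part of the API): the producer records
// fetched tensors in the response wrapper, and the consumer drains them
// later. Per the NOTE on `RecvValue()` above, reads may be destructive, so
// each index should be read at most once. Variable names are hypothetical.
//
//   // Producer (worker) side: append a tensor under its rendezvous key.
//   response->AddRecv(recv_key, fetched_tensor);
//
//   // Consumer side: move the tensors out of the response.
//   for (size_t i = 0; i < response->num_recvs(); ++i) {
//     Tensor val;
//     Status s = response->RecvValue(i, &val);
//     if (!s.ok()) return s;
//     // ... use response->recv_key(i) and `val` ...
//   }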
////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `MasterService.RunStep` response message.
//
// The `RunStepResponse` message can contain potentially large tensor
// data as part of its `tensor` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for a mutable RunStepResponse message.
//
// Note that there is no corresponding (immutable)
// RunStepResponseWrapper class, because the RunStepResponse object is
// always used as a mutable pointer.
class MutableRunStepResponseWrapper {
 public:
  virtual ~MutableRunStepResponseWrapper();

  // The values of the tensors whose fetching was requested in the
  // RunStep call.
  //
  // NOTE: The order of the returned tensors may or may not match
  // the fetch order specified in RunStepRequest.
  virtual size_t num_tensors() const = 0;
  virtual const string& tensor_name(size_t i) const = 0;
  virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;

  // Stores the i^{th} recv value in `run_graph_response` in this
  // response with the given `name`.
  virtual Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) = 0;

  // Returned metadata if requested in the options.
  virtual const RunMetadata& metadata() const = 0;
  virtual RunMetadata* mutable_metadata() = 0;

  // Returned status if requested.
  virtual errors::Code status_code() const = 0;
  virtual const string& status_error_message() const = 0;
  virtual void set_status(const Status& status) = 0;

 protected:
  // Returns a mutable protobuf message that represents the contents of
  // this wrapper, for passing to an RPC subsystem that will populate
  // the message.
  //
  // NOTE: Only `MasterInterface` subclasses may call this method. The
  // `InMemoryRunStepResponse` subclass does not implement this
  // method, and attempts to call it will fail with a fatal
  // error. However, as long as callers always call
  // `MasterInterface::RunStep()` with a wrapper object returned
  // from `MasterInterface::CreateRunStepResponse()` called on the
  // *same* MasterInterface object, this error will never trigger.
  virtual RunStepResponse* get_proto() = 0;
  friend class MasterInterface;
};

class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
 public:
  // MutableRunStepResponseWrapper methods.
  size_t num_tensors() const override;
  const string& tensor_name(size_t i) const override;
  Status TensorValue(size_t i, Tensor* out_tensor) const override;
  Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) override;
  const RunMetadata& metadata() const override;
  RunMetadata* mutable_metadata() override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  // NOTE: This method is not implemented. See
  // MutableRunGraphResponseWrapper for an explanation.
  RunStepResponse* get_proto() override;

 private:
  gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
  RunMetadata metadata_;
  // Store the code and message separately so that they can be updated
  // independently by setters.
  Status status_;
};

// Proto-based message wrapper for use on the client side of the RunStep RPC.
class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
 public:
  // MutableRunStepResponseWrapper methods.
  size_t num_tensors() const override;
  const string& tensor_name(size_t i) const override;
  Status TensorValue(size_t i, Tensor* out_tensor) const override;
  Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) override;
  const RunMetadata& metadata() const override;
  RunMetadata* mutable_metadata() override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunStepResponse* get_proto() override;

 private:
  RunStepResponse response_;
};

// Proto-based message wrapper for use on the server side of the RunStep RPC.
class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
 public:
  NonOwnedProtoRunStepResponse(RunStepResponse* response);

  // MutableRunStepResponseWrapper methods.
  size_t num_tensors() const override;
  const string& tensor_name(size_t i) const override;
  Status TensorValue(size_t i, Tensor* out_tensor) const override;
  Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) override;
  const RunMetadata& metadata() const override;
  RunMetadata* mutable_metadata() override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunStepResponse* get_proto() override;

 private:
  RunStepResponse* response_;  // Not owned.
};
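// Example (illustrative sketch, not part of the API): once a RunGraph call
// completes, the master can transfer each received tensor into the RunStep
// response under the name the client asked to fetch. `FetchNameForKey()` is a
// hypothetical helper that maps a rendezvous key back to the client-visible
// fetch name; the other variable names are hypothetical as well.
//
//   for (size_t i = 0; i < run_graph_response->num_recvs(); ++i) {
//     Status s = run_step_response->AddTensorFromRunGraphResponse(
//         FetchNameForKey(run_graph_response->recv_key(i)),
//         run_graph_response, i);
//     if (!s.ok()) return s;
//   }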
// Attempts to parse `tensor_proto` into `*out_tensor`. Returns true on
// success, or false if the proto does not represent a valid tensor.
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
                              Tensor* out_tensor);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_