1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 18 19 #include "tensorflow/core/framework/allocator.h" 20 #include "tensorflow/core/framework/cost_graph.pb.h" 21 #include "tensorflow/core/framework/graph.pb.h" 22 #include "tensorflow/core/framework/step_stats.pb.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor.pb_text.h" 25 #include "tensorflow/core/framework/versions.pb.h" 26 #include "tensorflow/core/protobuf/config.pb.h" 27 #include "tensorflow/core/protobuf/master.pb.h" 28 #include "tensorflow/core/protobuf/worker.pb.h" 29 30 namespace tensorflow { 31 32 //////////////////////////////////////////////////////////////////////////////// 33 // 34 // Wrapper classes for the `MasterService.RunStep` request message. 35 // 36 // The `RunStepRequest` message can contain potentially large tensor 37 // data as part of its `feed` submessages. Here we provide specialized 38 // wrappers that avoid copying the tensor data wherever possible. 39 // 40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the 41 // protocol buffer definition. 42 // 43 //////////////////////////////////////////////////////////////////////////////// 44 45 // Abstract interface for an immutable RunStepRequest message. 46 // 47 // This interface is typically used by server-side components in the 48 // TensorFlow master. 49 class RunStepRequestWrapper { 50 public: ~RunStepRequestWrapper()51 virtual ~RunStepRequestWrapper() {} 52 53 // REQUIRED: session_handle must be returned by a CreateSession call 54 // to the same master service. 55 virtual const string& session_handle() const = 0; 56 57 // Partial run handle (optional). If specified, this will be a partial run 58 // execution, run up to the specified fetches. 59 virtual const string& partial_run_handle() const = 0; 60 61 // Tensors to be fed in the step. Each feed is a named tensor. 62 virtual size_t num_feeds() const = 0; 63 virtual const string& feed_name(size_t i) const = 0; 64 65 // Stores the content of the feed value at index `i` in `tensor`. 66 virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0; 67 virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; 68 69 // Fetches. A list of tensor names. The caller expects a tensor to 70 // be returned for each fetch[i] (see RunStepResponse.tensor). The 71 // order of specified fetches does not change the execution order. 72 virtual size_t num_fetches() const = 0; 73 virtual const string& fetch_name(size_t i) const = 0; 74 75 // Target Nodes. A list of node names. The named nodes will be run 76 // to but their outputs will not be fetched. 77 virtual size_t num_targets() const = 0; 78 virtual const string& target_name(size_t i) const = 0; 79 80 // Options for the run call. 81 virtual const RunOptions& options() const = 0; 82 83 // If true then some errors, e.g., execution errors that have long 84 // error messages, may return an OK RunStepResponse with the actual 85 // error saved in the status_code/status_error_message fields of the 86 // response body. This is a workaround since the RPC subsystem may 87 // truncate long metadata messages. 88 virtual bool store_errors_in_response_body() const = 0; 89 90 virtual int64 request_id() const = 0; 91 92 // Returns a human-readable representation of this message for debugging. 93 virtual string DebugString() const = 0; 94 95 // Returns the wrapped data as a protocol buffer message. 96 virtual const RunStepRequest& ToProto() const = 0; 97 }; 98 99 // Abstract interface for a mutable RunStepRequest message. 100 // 101 // See `RunStepRequestWrapper` above for a description of the fields. 102 class MutableRunStepRequestWrapper : public RunStepRequestWrapper { 103 public: 104 virtual void set_session_handle(const string& handle) = 0; 105 virtual void set_partial_run_handle(const string& handle) = 0; 106 virtual void add_feed(const string& name, const Tensor& value) = 0; 107 virtual void add_fetch(const string& name) = 0; 108 virtual void add_target(const string& name) = 0; 109 virtual RunOptions* mutable_options() = 0; 110 virtual void set_store_errors_in_response_body(bool store_errors) = 0; 111 }; 112 113 // Specialized (and mutable) wrapper for RunStep requests between a client and 114 // master in the same address space. 115 class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { 116 public: 117 // RunStepRequestWrapper methods. 118 const string& session_handle() const override; 119 const string& partial_run_handle() const override; 120 size_t num_feeds() const override; 121 const string& feed_name(size_t i) const override; 122 Status FeedValue(size_t i, Tensor* out_tensor) const override; 123 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 124 size_t num_fetches() const override; 125 const string& fetch_name(size_t i) const override; 126 size_t num_targets() const override; 127 const string& target_name(size_t i) const override; 128 const RunOptions& options() const override; 129 string DebugString() const override; 130 const RunStepRequest& ToProto() const override; 131 bool store_errors_in_response_body() const override; 132 int64 request_id() const override; 133 134 // MutableRunStepRequestWrapper methods. 135 void set_session_handle(const string& handle) override; 136 void set_partial_run_handle(const string& handle) override; 137 void add_feed(const string& name, const Tensor& value) override; 138 void add_fetch(const string& name) override; 139 void add_target(const string& name) override; 140 RunOptions* mutable_options() override; 141 void set_store_errors_in_response_body(bool store_errors) override; 142 143 private: 144 string session_handle_; 145 string partial_run_handle_; 146 gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_; 147 gtl::InlinedVector<string, 4> fetches_; 148 gtl::InlinedVector<string, 4> targets_; 149 RunOptions options_; 150 bool store_errors_in_response_body_ = false; 151 152 // Holds a cached and owned representation of the proto 153 // representation of this request, if needed, so that `ToProto()` 154 // can return a const RunStepRequest&. 155 // NOTE(mrry): Although calls to `ToProto()` on this class are 156 // expected to be rare, retaining ownership of the returned message 157 // makes it easier to return a reference from the proto-backed 158 // representations. 159 mutable std::unique_ptr<RunStepRequest> proto_version_; 160 }; 161 162 // Wrapper for mutable RunStep requests that uses a protobuf message. 163 // 164 // This wrapper class should be used for RunStep requests between a 165 // client and master in different address spaces. 166 class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { 167 public: 168 // RunStepRequestWrapper methods. 169 const string& session_handle() const override; 170 const string& partial_run_handle() const override; 171 size_t num_feeds() const override; 172 const string& feed_name(size_t i) const override; 173 Status FeedValue(size_t i, Tensor* out_tensor) const override; 174 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 175 size_t num_fetches() const override; 176 const string& fetch_name(size_t i) const override; 177 size_t num_targets() const override; 178 const string& target_name(size_t i) const override; 179 const RunOptions& options() const override; 180 string DebugString() const override; 181 const RunStepRequest& ToProto() const override; 182 bool store_errors_in_response_body() const override; 183 int64 request_id() const override; 184 185 // MutableRunStepRequestWrapper methods. 186 void set_session_handle(const string& handle) override; 187 void set_partial_run_handle(const string& handle) override; 188 void add_feed(const string& name, const Tensor& value) override; 189 void add_fetch(const string& name) override; 190 void add_target(const string& name) override; 191 RunOptions* mutable_options() override; 192 void set_store_errors_in_response_body(bool store_errors) override; 193 194 private: 195 RunStepRequest request_; 196 friend class MasterInterface; 197 }; 198 199 // Wrapper for immutable RunStep requests that use a non-owned 200 // protobuf message. 201 // 202 // This interface is typically used by server-side components in the 203 // TensorFlow master, where the incoming message is a (possibly const) 204 // `RunStepRequest*`. 205 class ProtoRunStepRequest : public RunStepRequestWrapper { 206 public: 207 ProtoRunStepRequest(const RunStepRequest* request); 208 209 // RunStepRequestWrapper methods. 210 const string& session_handle() const override; 211 const string& partial_run_handle() const override; 212 size_t num_feeds() const override; 213 const string& feed_name(size_t i) const override; 214 Status FeedValue(size_t i, Tensor* out_tensor) const override; 215 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 216 size_t num_fetches() const override; 217 const string& fetch_name(size_t i) const override; 218 size_t num_targets() const override; 219 const string& target_name(size_t i) const override; 220 const RunOptions& options() const override; 221 string DebugString() const override; 222 const RunStepRequest& ToProto() const override; 223 bool store_errors_in_response_body() const override; 224 int64 request_id() const override; 225 226 private: 227 const RunStepRequest* const request_; // Not owned. 228 }; 229 230 //////////////////////////////////////////////////////////////////////////////// 231 // 232 // Wrapper classes for the `WorkerService.RunGraph` request message. 233 // 234 // The `RunGraphRequest` message can contain potentially large tensor 235 // data as part of its `send` submessages. Here we provide specialized 236 // wrappers that avoid copying the tensor data wherever possible. 237 // 238 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the 239 // protocol buffer definition. 240 // 241 //////////////////////////////////////////////////////////////////////////////// 242 243 // Abstract interface for an immutable RunGraphRequest message. 244 // 245 // This interface is typically used by server-side components in the 246 // TensorFlow worker. 247 class RunGraphRequestWrapper { 248 public: ~RunGraphRequestWrapper()249 virtual ~RunGraphRequestWrapper() {} 250 251 // The session handle used to register the graph. If empty, a single global 252 // namespace is used. 253 virtual const string& session_handle() const = 0; 254 255 // Set to true if `CreateWorkerSession` was called for `session_handle`. 256 virtual bool create_worker_session_called() const = 0; 257 258 // REQUIRED: graph_handle must be returned by a RegisterGraph call 259 // to the same WorkerService. 260 virtual const string& graph_handle() const = 0; 261 262 // A unique ID to distinguish different runs of the same graph. 263 // 264 // The master generates a global unique `step_id` to distinguish 265 // different runs of the graph computation. Subgraphs communicate 266 // (e.g., send/recv ops) with each other using `step_id` to 267 // distinguish tensors generated by different runs. 268 virtual int64 step_id() const = 0; 269 270 // Options for this step. 271 virtual const ExecutorOpts& exec_opts() const = 0; 272 273 // Sends the tensors in "send" into the graph before the run. 274 virtual size_t num_sends() const = 0; 275 virtual const string& send_key(size_t i) const = 0; 276 virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0; 277 278 // Fetches the keys into `RunGraphResponse.recv` after the run. 279 virtual size_t num_recvs() const = 0; 280 virtual const string& recv_key(size_t i) const = 0; 281 282 // True if the RunGraphRequest is a partial run request. 283 virtual bool is_partial() const = 0; 284 285 // True if this is the last partial run request in a sequence of requests. 286 virtual bool is_last_partial_run() const = 0; 287 288 // If true then some errors, e.g., execution errors that have long 289 // error messages, may return an OK RunStepResponse with the actual 290 // error saved in the status_code/status_error_message fields of the 291 // response body. This is a workaround since the RPC subsystem may 292 // truncate long metadata messages. 293 virtual bool store_errors_in_response_body() const = 0; 294 295 // Returns the wrapped data as a protocol buffer message. 296 virtual const RunGraphRequest& ToProto() const = 0; 297 }; 298 299 // Abstract interface for a mutable RunGraphRequest message. 300 // 301 // See `RunGraphRequestWrapper` above for a description of the fields. 302 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { 303 public: 304 virtual void set_session_handle(const string& handle) = 0; 305 virtual void set_create_worker_session_called(bool called) = 0; 306 virtual void set_graph_handle(const string& handle) = 0; 307 virtual void set_step_id(int64 step_id) = 0; 308 virtual ExecutorOpts* mutable_exec_opts() = 0; 309 310 // Stores the i^{th} feed value in `run_step_request` in this 311 // request with the given `send_key`. 312 virtual Status AddSendFromRunStepRequest( 313 const RunStepRequestWrapper& run_step_request, size_t i, 314 const string& send_key) = 0; 315 virtual Status AddSendFromRunCallableRequest( 316 const RunCallableRequest& run_callable_request, size_t i, 317 const string& send_key) = 0; 318 319 virtual void add_recv_key(const string& recv_key) = 0; 320 virtual void set_is_partial(bool is_partial) = 0; 321 virtual void set_is_last_partial_run(bool is_last_partial_run) = 0; 322 virtual void set_store_errors_in_response_body(bool store_errors) = 0; 323 }; 324 325 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { 326 public: 327 // RunGraphRequestWrapper methods. 328 const string& session_handle() const override; 329 const string& graph_handle() const override; 330 bool create_worker_session_called() const override; 331 int64 step_id() const override; 332 const ExecutorOpts& exec_opts() const override; 333 size_t num_sends() const override; 334 const string& send_key(size_t i) const override; 335 Status SendValue(size_t i, Tensor* out_tensor) const override; 336 size_t num_recvs() const override; 337 const string& recv_key(size_t i) const override; 338 bool is_partial() const override; 339 bool is_last_partial_run() const override; 340 const RunGraphRequest& ToProto() const override; 341 bool store_errors_in_response_body() const override; 342 343 // MutableRunGraphRequestWrapper methods. 344 void set_session_handle(const string& handle) override; 345 void set_create_worker_session_called(bool called) override; 346 void set_graph_handle(const string& handle) override; 347 void set_step_id(int64 step_id) override; 348 ExecutorOpts* mutable_exec_opts() override; 349 Status AddSendFromRunStepRequest( 350 const RunStepRequestWrapper& run_step_request, size_t i, 351 const string& send_key) override; 352 Status AddSendFromRunCallableRequest( 353 const RunCallableRequest& run_callable_request, size_t i, 354 const string& send_key) override; 355 void add_recv_key(const string& recv_key) override; 356 void set_is_partial(bool is_partial) override; 357 void set_is_last_partial_run(bool is_last_partial_run) override; 358 void set_store_errors_in_response_body(bool store_errors) override; 359 360 private: 361 string session_handle_; 362 bool create_worker_session_called_ = false; 363 string graph_handle_; 364 int64 step_id_; 365 ExecutorOpts exec_opts_; 366 gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_; 367 gtl::InlinedVector<string, 4> recvs_; 368 bool is_partial_ = false; 369 bool is_last_partial_run_ = false; 370 bool store_errors_in_response_body_ = false; 371 372 // Holds a cached and owned representation of the proto 373 // representation of this request, if needed, so that `ToProto()` 374 // can return a const RunGraphRequest&. 375 // NOTE(mrry): Although calls to `ToProto()` on this class are 376 // expected to be rare, retaining ownership of the returned message 377 // makes it easier to return a reference from the proto-backed 378 // representations. 379 mutable std::unique_ptr<RunGraphRequest> proto_version_; 380 }; 381 382 class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { 383 public: 384 // RunGraphRequestWrapper methods. 385 const string& session_handle() const override; 386 bool create_worker_session_called() const override; 387 const string& graph_handle() const override; 388 int64 step_id() const override; 389 const ExecutorOpts& exec_opts() const override; 390 size_t num_sends() const override; 391 const string& send_key(size_t i) const override; 392 Status SendValue(size_t i, Tensor* out_tensor) const override; 393 size_t num_recvs() const override; 394 const string& recv_key(size_t i) const override; 395 bool is_partial() const override; 396 bool is_last_partial_run() const override; 397 bool store_errors_in_response_body() const override; 398 const RunGraphRequest& ToProto() const override; 399 400 // MutableRunGraphRequestWrapper methods. 401 void set_session_handle(const string& handle) override; 402 void set_create_worker_session_called(bool called) override; 403 void set_graph_handle(const string& handle) override; 404 void set_step_id(int64 step_id) override; 405 ExecutorOpts* mutable_exec_opts() override; 406 Status AddSendFromRunStepRequest( 407 const RunStepRequestWrapper& run_step_request, size_t i, 408 const string& send_key) override; 409 Status AddSendFromRunCallableRequest( 410 const RunCallableRequest& run_callable_request, size_t i, 411 const string& send_key) override; 412 void add_recv_key(const string& recv_key) override; 413 void set_is_partial(bool is_partial) override; 414 void set_is_last_partial_run(bool is_last_partial_run) override; 415 void set_store_errors_in_response_body(bool store_errors) override; 416 417 private: 418 RunGraphRequest request_; 419 }; 420 421 class ProtoRunGraphRequest : public RunGraphRequestWrapper { 422 public: 423 ProtoRunGraphRequest(const RunGraphRequest* request); 424 425 // RunGraphRequestWrapper methods. 426 const string& session_handle() const override; 427 bool create_worker_session_called() const override; 428 const string& graph_handle() const override; 429 int64 step_id() const override; 430 const ExecutorOpts& exec_opts() const override; 431 size_t num_sends() const override; 432 const string& send_key(size_t i) const override; 433 Status SendValue(size_t i, Tensor* out_tensor) const override; 434 size_t num_recvs() const override; 435 const string& recv_key(size_t i) const override; 436 bool is_partial() const override; 437 bool is_last_partial_run() const override; 438 bool store_errors_in_response_body() const override; 439 const RunGraphRequest& ToProto() const override; 440 441 private: 442 const RunGraphRequest* const request_; // Not owned. 443 }; 444 445 //////////////////////////////////////////////////////////////////////////////// 446 // 447 // Wrapper classes for the `WorkerService.RunGraph` response message. 448 // 449 // The `RunGraphResponse` message can contain potentially large tensor 450 // data as part of its `recv` submessages. Here we provide specialized 451 // wrappers that avoid copying the tensor data wherever possible. 452 // 453 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the 454 // protocol buffer definition. 455 // 456 //////////////////////////////////////////////////////////////////////////////// 457 458 // Abstract interface for a mutable RunGraphResponse message. 459 // 460 // Note that there is no corresponding (immutable) 461 // RunGraphResponseWrapper class, because the RunGraphResponse object 462 // is always used as a mutable pointer. 463 class MutableRunGraphResponseWrapper { 464 public: ~MutableRunGraphResponseWrapper()465 virtual ~MutableRunGraphResponseWrapper() {} 466 467 // A list of tensors corresponding to those requested by 468 // `RunGraphRequest.recv_key`. 469 virtual size_t num_recvs() const = 0; 470 virtual const string& recv_key(size_t i) const = 0; 471 // NOTE: The following methods may perform a destructive read, for 472 // efficiency. 473 virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0; 474 virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0; 475 virtual void AddRecv(const string& key, const Tensor& value) = 0; 476 477 // Submessages that store performance statistics about the subgraph 478 // execution, if necessary. 479 virtual StepStats* mutable_step_stats() = 0; 480 virtual CostGraphDef* mutable_cost_graph() = 0; 481 virtual size_t num_partition_graphs() const = 0; 482 virtual GraphDef* mutable_partition_graph(size_t i) = 0; 483 virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; 484 485 // Returned status if requested. 486 virtual errors::Code status_code() const = 0; 487 virtual const string& status_error_message() const = 0; 488 virtual void set_status(const Status& status) = 0; 489 490 protected: 491 // Returns a mutable protobuf message that represents the contents of 492 // this wrapper, for passing to an RPC subsystem that will populate 493 // the message. 494 // 495 // NOTE: Only `WorkerInterface` subclasses may call this method. The 496 // `InMemoryRunGraphResponse` subclass does not implement this 497 // method, and attempts to call it will fail with a fatal 498 // error. However, as long as callers always call 499 // `WorkerInterface::RunGraphAsync()` with a wrapper object returned 500 // from `WorkerInterface::CreateRunGraphResponse()` called on the 501 // *same* WorkerInterface object, this error will never trigger. 502 virtual RunGraphResponse* get_proto() = 0; 503 friend class WorkerInterface; 504 }; 505 506 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { 507 public: 508 // MutableRunGraphResponseWrapper methods. 509 size_t num_recvs() const override; 510 const string& recv_key(size_t i) const override; 511 Status RecvValue(size_t i, TensorProto* out_tensor) override; 512 Status RecvValue(size_t i, Tensor* out_tensor) override; 513 void AddRecv(const string& key, const Tensor& value) override; 514 StepStats* mutable_step_stats() override; 515 CostGraphDef* mutable_cost_graph() override; 516 size_t num_partition_graphs() const override; 517 GraphDef* mutable_partition_graph(size_t i) override; 518 void AddPartitionGraph(const GraphDef& partition_graph) override; 519 errors::Code status_code() const override; 520 const string& status_error_message() const override; 521 void set_status(const Status& status) override; 522 523 protected: 524 // NOTE: This method is not implemented. See 525 // MutableRunGraphResponseWrapper for an explanation. 526 RunGraphResponse* get_proto() override; 527 528 private: 529 gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_; 530 StepStats step_stats_; 531 CostGraphDef cost_graph_; 532 std::vector<GraphDef> partition_graphs_; 533 // Store the code and message separately so that they can be updated 534 // independently by setters. 535 Status status_; 536 }; 537 538 // Proto-based message wrapper for use on the client side of the RunGraph RPC. 539 class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { 540 public: 541 // MutableRunGraphResponseWrapper methods. 542 size_t num_recvs() const override; 543 const string& recv_key(size_t i) const override; 544 Status RecvValue(size_t i, TensorProto* out_tensor) override; 545 Status RecvValue(size_t i, Tensor* out_tensor) override; 546 void AddRecv(const string& key, const Tensor& value) override; 547 StepStats* mutable_step_stats() override; 548 CostGraphDef* mutable_cost_graph() override; 549 size_t num_partition_graphs() const override; 550 GraphDef* mutable_partition_graph(size_t i) override; 551 void AddPartitionGraph(const GraphDef& partition_graph) override; 552 errors::Code status_code() const override; 553 const string& status_error_message() const override; 554 void set_status(const Status& status) override; 555 556 protected: 557 RunGraphResponse* get_proto() override; 558 559 private: 560 RunGraphResponse response_; 561 }; 562 563 // Proto-based message wrapper for use on the server side of the RunGraph RPC. 564 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { 565 public: 566 NonOwnedProtoRunGraphResponse(RunGraphResponse* response); 567 568 // MutableRunGraphResponseWrapper methods. 569 size_t num_recvs() const override; 570 const string& recv_key(size_t i) const override; 571 Status RecvValue(size_t i, TensorProto* out_tensor) override; 572 Status RecvValue(size_t i, Tensor* out_tensor) override; 573 void AddRecv(const string& key, const Tensor& value) override; 574 StepStats* mutable_step_stats() override; 575 CostGraphDef* mutable_cost_graph() override; 576 size_t num_partition_graphs() const override; 577 GraphDef* mutable_partition_graph(size_t i) override; 578 void AddPartitionGraph(const GraphDef& partition_graph) override; 579 errors::Code status_code() const override; 580 const string& status_error_message() const override; 581 void set_status(const Status& status) override; 582 583 protected: 584 RunGraphResponse* get_proto() override; 585 586 private: 587 RunGraphResponse* const response_; 588 }; 589 590 //////////////////////////////////////////////////////////////////////////////// 591 // 592 // Wrapper classes for the `MasterService.RunStep` response message. 593 // 594 // The `RunStepResponse` message can contain potentially large tensor 595 // data as part of its `tensor` submessages. Here we provide specialized 596 // wrappers that avoid copying the tensor data wherever possible. 597 // 598 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the 599 // protocol buffer definition. 600 // 601 //////////////////////////////////////////////////////////////////////////////// 602 603 // Abstract interface for a mutable RunStepResponse message. 604 // 605 // Note that there is no corresponding (immutable) 606 // RunStepResponseWrapper class, because the RunStepResponse object is 607 // always used as a mutable pointer. 608 class MutableRunStepResponseWrapper { 609 public: 610 virtual ~MutableRunStepResponseWrapper(); 611 612 // The values of the tensors whose fetching was requested in the 613 // RunStep call. 614 // 615 // NOTE: The order of the returned tensors may or may not match 616 // the fetch order specified in RunStepRequest. 617 virtual size_t num_tensors() const = 0; 618 virtual const string& tensor_name(size_t i) const = 0; 619 virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0; 620 621 // Stores the i^{th} recv value in `run_graph_response` in this 622 // response with the given `name`. 623 virtual Status AddTensorFromRunGraphResponse( 624 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 625 size_t i) = 0; 626 627 // Returned metadata if requested in the options. 628 virtual const RunMetadata& metadata() const = 0; 629 virtual RunMetadata* mutable_metadata() = 0; 630 631 // Returned status if requested. 632 virtual errors::Code status_code() const = 0; 633 virtual const string& status_error_message() const = 0; 634 virtual void set_status(const Status& status) = 0; 635 636 protected: 637 // Returns a mutable protobuf message that represents the contents of 638 // this wrapper, for passing to an RPC subsystem that will populate 639 // the message. 640 // 641 // NOTE: Only `MasterInterface` subclasses may call this method. The 642 // `InMemoryRunStepResponse` subclass does not implement this 643 // method, and attempts to call it will fail with a fatal 644 // error. However, as long as callers always call 645 // `MasterInterface::RunStep()` with a wrapper object returned 646 // from `MasterInterface::CreateRunStepResponse()` called on the 647 // *same* MasterInterface object, this error will never trigger. 648 virtual RunStepResponse* get_proto() = 0; 649 friend class MasterInterface; 650 }; 651 652 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { 653 public: 654 // MutableRunStepResponseWrapper methods. 655 size_t num_tensors() const override; 656 const string& tensor_name(size_t i) const override; 657 Status TensorValue(size_t i, Tensor* out_tensor) const override; 658 Status AddTensorFromRunGraphResponse( 659 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 660 size_t i) override; 661 const RunMetadata& metadata() const override; 662 RunMetadata* mutable_metadata() override; 663 errors::Code status_code() const override; 664 const string& status_error_message() const override; 665 void set_status(const Status& status) override; 666 667 protected: 668 // NOTE: This method is not implemented. See 669 // MutableRunGraphResponseWrapper for an explanation. 670 RunStepResponse* get_proto() override; 671 672 private: 673 gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_; 674 RunMetadata metadata_; 675 // Store the code and message separately so that they can be updated 676 // independently by setters. 677 Status status_; 678 }; 679 680 // Proto-based message wrapper for use on the client side of the RunStep RPC. 681 class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { 682 public: 683 // MutableRunStepResponseWrapper methods. 684 size_t num_tensors() const override; 685 const string& tensor_name(size_t i) const override; 686 Status TensorValue(size_t i, Tensor* out_tensor) const override; 687 Status AddTensorFromRunGraphResponse( 688 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 689 size_t i) override; 690 const RunMetadata& metadata() const override; 691 RunMetadata* mutable_metadata() override; 692 errors::Code status_code() const override; 693 const string& status_error_message() const override; 694 void set_status(const Status& status) override; 695 696 protected: 697 RunStepResponse* get_proto() override; 698 699 private: 700 RunStepResponse response_; 701 }; 702 703 // Proto-based message wrapper for use on the server side of the RunStep RPC. 704 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { 705 public: 706 NonOwnedProtoRunStepResponse(RunStepResponse* response); 707 708 // MutableRunStepResponseWrapper methods. 709 size_t num_tensors() const override; 710 const string& tensor_name(size_t i) const override; 711 Status TensorValue(size_t i, Tensor* out_tensor) const override; 712 Status AddTensorFromRunGraphResponse( 713 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 714 size_t i) override; 715 const RunMetadata& metadata() const override; 716 RunMetadata* mutable_metadata() override; 717 errors::Code status_code() const override; 718 const string& status_error_message() const override; 719 void set_status(const Status& status) override; 720 721 protected: 722 RunStepResponse* get_proto() override; 723 724 private: 725 RunStepResponse* response_; // Not owned. 726 }; 727 728 } // namespace tensorflow 729 730 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 731