• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
18 
19 #include "tensorflow/core/framework/allocator.h"
20 #include "tensorflow/core/framework/cost_graph.pb.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/step_stats.pb.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/protobuf/master.pb.h"
28 #include "tensorflow/core/protobuf/worker.pb.h"
29 
30 namespace tensorflow {
31 
32 ////////////////////////////////////////////////////////////////////////////////
33 //
34 // Wrapper classes for the `MasterService.RunStep` request message.
35 //
36 // The `RunStepRequest` message can contain potentially large tensor
37 // data as part of its `feed` submessages. Here we provide specialized
38 // wrappers that avoid copying the tensor data wherever possible.
39 //
40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the
41 // protocol buffer definition.
42 //
43 ////////////////////////////////////////////////////////////////////////////////
44 
45 // Abstract interface for an immutable RunStepRequest message.
46 //
47 // This interface is typically used by server-side components in the
48 // TensorFlow master.
//
// Every accessor is const and pure virtual; the storage is supplied by the
// concrete subclasses.
49 class RunStepRequestWrapper {
50  public:
51   virtual ~RunStepRequestWrapper() {}
52 
53   // REQUIRED: session_handle must be returned by a CreateSession call
54   // to the same master service.
55   virtual const string& session_handle() const = 0;
56 
57   // Partial run handle (optional). If specified, this will be a partial run
58   // execution, run up to the specified fetches.
59   virtual const string& partial_run_handle() const = 0;
60 
61   // Tensors to be fed in the step. Each feed is a named tensor.
62   virtual size_t num_feeds() const = 0;
63   virtual const string& feed_name(size_t i) const = 0;
64 
65   // Stores the content of the feed value at index `i` in `out_tensor`.
  // Two overloads are provided: one producing a `Tensor` and one producing
  // a `TensorProto`. The returned Status reports the outcome.
66   virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
67   virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;
68 
69   // Fetches. A list of tensor names. The caller expects a tensor to
70   // be returned for each fetch[i] (see RunStepResponse.tensor). The
71   // order of specified fetches does not change the execution order.
72   virtual size_t num_fetches() const = 0;
73   virtual const string& fetch_name(size_t i) const = 0;
74 
75   // Target Nodes. A list of node names. The named nodes will be run
76   // to but their outputs will not be fetched.
77   virtual size_t num_targets() const = 0;
78   virtual const string& target_name(size_t i) const = 0;
79 
80   // Options for the run call.
81   virtual const RunOptions& options() const = 0;
82 
83   // If true then some errors, e.g., execution errors that have long
84   // error messages, may return an OK RunStepResponse with the actual
85   // error saved in the status_code/status_error_message fields of the
86   // response body. This is a workaround since the RPC subsystem may
87   // truncate long metadata messages.
88   virtual bool store_errors_in_response_body() const = 0;
89 
90   // Unique identifier for this request. Every RunStepRequest must have a
91   // unique request_id, and retried RunStepRequests must have the same
92   // request_id. If request_id is zero, retry detection is disabled.
93   virtual int64 request_id() const = 0;
94 
95   // Returns a human-readable representation of this message for debugging.
96   virtual string DebugString() const = 0;
97 
98   // Returns the wrapped data as a protocol buffer message.
99   virtual const RunStepRequest& ToProto() const = 0;
100 };
101 
102 // Abstract interface for a mutable RunStepRequest message.
103 //
104 // See `RunStepRequestWrapper` above for a description of the fields.
// This interface adds pure-virtual mutators on top of the inherited
// read-only accessors.
//
// NOTE(review): no setter for `request_id` is declared here even though
// `request_id()` is readable on the base interface — confirm how request
// ids are assigned for RunStep requests.
105 class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
106  public:
  // Set the session handle / partial-run handle of the request.
107   virtual void set_session_handle(const string& handle) = 0;
108   virtual void set_partial_run_handle(const string& handle) = 0;
  // Add one named feed tensor, one fetch name, or one target name.
109   virtual void add_feed(const string& name, const Tensor& value) = 0;
110   virtual void add_fetch(const string& name) = 0;
111   virtual void add_target(const string& name) = 0;
  // Returns a pointer through which the run options can be modified.
112   virtual RunOptions* mutable_options() = 0;
  // Sets the flag read back by `store_errors_in_response_body()`.
113   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
114 };
115 
116 // Specialized (and mutable) wrapper for RunStep requests between a client and
117 // master in the same address space.
//
// The request fields are kept as plain members (feed values as `Tensor`
// objects); a `RunStepRequest` proto is only held as a cache for `ToProto()`.
118 class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
119  public:
120   // RunStepRequestWrapper methods.
121   const string& session_handle() const override;
122   const string& partial_run_handle() const override;
123   size_t num_feeds() const override;
124   const string& feed_name(size_t i) const override;
125   Status FeedValue(size_t i, Tensor* out_tensor) const override;
126   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
127   size_t num_fetches() const override;
128   const string& fetch_name(size_t i) const override;
129   size_t num_targets() const override;
130   const string& target_name(size_t i) const override;
131   const RunOptions& options() const override;
132   string DebugString() const override;
133   const RunStepRequest& ToProto() const override;
134   bool store_errors_in_response_body() const override;
135   int64 request_id() const override;
136 
137   // MutableRunStepRequestWrapper methods.
138   void set_session_handle(const string& handle) override;
139   void set_partial_run_handle(const string& handle) override;
140   void add_feed(const string& name, const Tensor& value) override;
141   void add_fetch(const string& name) override;
142   void add_target(const string& name) override;
143   RunOptions* mutable_options() override;
144   void set_store_errors_in_response_body(bool store_errors) override;
145 
146  private:
147   string session_handle_;
148   string partial_run_handle_;
  // (name, tensor) pairs, fetch names and target names; each vector has
  // inline capacity for 4 entries. Presumably filled by the add_*() methods
  // above — confirm in the .cc file.
149   gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_;
150   gtl::InlinedVector<string, 4> fetches_;
151   gtl::InlinedVector<string, 4> targets_;
152   RunOptions options_;
153   bool store_errors_in_response_body_ = false;
154 
155   // Holds a cached and owned representation of the proto
156   // representation of this request, if needed, so that `ToProto()`
157   // can return a const RunStepRequest&.
158   // NOTE(mrry): Although calls to `ToProto()` on this class are
159   // expected to be rare, retaining ownership of the returned message
160   // makes it easier to return a reference from the proto-backed
161   // representations.
  // Declared `mutable` so that the const `ToProto()` can populate it.
162   mutable std::unique_ptr<RunStepRequest> proto_version_;
163 };
164 
165 // Wrapper for mutable RunStep requests that uses a protobuf message.
166 //
167 // This wrapper class should be used for RunStep requests between a
168 // client and master in different address spaces.
//
// The wrapped `RunStepRequest` is owned by this object (held by value).
169 class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
170  public:
171   // RunStepRequestWrapper methods.
172   const string& session_handle() const override;
173   const string& partial_run_handle() const override;
174   size_t num_feeds() const override;
175   const string& feed_name(size_t i) const override;
176   Status FeedValue(size_t i, Tensor* out_tensor) const override;
177   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
178   size_t num_fetches() const override;
179   const string& fetch_name(size_t i) const override;
180   size_t num_targets() const override;
181   const string& target_name(size_t i) const override;
182   const RunOptions& options() const override;
183   string DebugString() const override;
184   const RunStepRequest& ToProto() const override;
185   bool store_errors_in_response_body() const override;
186   int64 request_id() const override;
187 
188   // MutableRunStepRequestWrapper methods.
189   void set_session_handle(const string& handle) override;
190   void set_partial_run_handle(const string& handle) override;
191   void add_feed(const string& name, const Tensor& value) override;
192   void add_fetch(const string& name) override;
193   void add_target(const string& name) override;
194   RunOptions* mutable_options() override;
195   void set_store_errors_in_response_body(bool store_errors) override;
196 
197  private:
  // The owned protobuf message that backs every accessor and mutator.
198   RunStepRequest request_;
  // Grants `MasterInterface` direct access to the private `request_`.
199   friend class MasterInterface;
200 };
201 
202 // Wrapper for immutable RunStep requests that use a non-owned
203 // protobuf message.
204 //
205 // This interface is typically used by server-side components in the
206 // TensorFlow master, where the incoming message is a (possibly const)
207 // `RunStepRequest*`.
208 class ProtoRunStepRequest : public RunStepRequestWrapper {
209  public:
  // Wraps `request` without taking ownership (presumably stored in
  // `request_` below, so it must outlive this wrapper).
  // NOTE(review): this single-argument constructor is not `explicit`, so a
  // `const RunStepRequest*` converts implicitly to this wrapper — confirm
  // that this is intended.
210   ProtoRunStepRequest(const RunStepRequest* request);
211 
212   // RunStepRequestWrapper methods.
213   const string& session_handle() const override;
214   const string& partial_run_handle() const override;
215   size_t num_feeds() const override;
216   const string& feed_name(size_t i) const override;
217   Status FeedValue(size_t i, Tensor* out_tensor) const override;
218   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
219   size_t num_fetches() const override;
220   const string& fetch_name(size_t i) const override;
221   size_t num_targets() const override;
222   const string& target_name(size_t i) const override;
223   const RunOptions& options() const override;
224   string DebugString() const override;
225   const RunStepRequest& ToProto() const override;
226   bool store_errors_in_response_body() const override;
227   int64 request_id() const override;
228 
229  private:
  // Both the pointer and the pointee are const: the wrapper can neither be
  // re-pointed nor modify the message.
230   const RunStepRequest* const request_;  // Not owned.
231 };
232 
233 ////////////////////////////////////////////////////////////////////////////////
234 //
235 // Wrapper classes for the `WorkerService.RunGraph` request message.
236 //
237 // The `RunGraphRequest` message can contain potentially large tensor
238 // data as part of its `send` submessages. Here we provide specialized
239 // wrappers that avoid copying the tensor data wherever possible.
240 //
241 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the
242 // protocol buffer definition.
243 //
244 ////////////////////////////////////////////////////////////////////////////////
245 
246 // Abstract interface for an immutable RunGraphRequest message.
247 //
248 // This interface is typically used by server-side components in the
249 // TensorFlow worker.
//
// Every accessor is const and pure virtual; the storage is supplied by the
// concrete subclasses.
250 class RunGraphRequestWrapper {
251  public:
252   virtual ~RunGraphRequestWrapper() {}
253 
254   // The session handle used to register the graph. If empty, a single global
255   // namespace is used.
256   virtual const string& session_handle() const = 0;
257 
258   // Set to true if `CreateWorkerSession` was called for `session_handle`.
259   virtual bool create_worker_session_called() const = 0;
260 
261   // REQUIRED: graph_handle must be returned by a RegisterGraph call
262   // to the same WorkerService.
263   virtual const string& graph_handle() const = 0;
264 
265   // A unique ID to distinguish different runs of the same graph.
266   //
267   // The master generates a global unique `step_id` to distinguish
268   // different runs of the graph computation. Subgraphs communicate
269   // (e.g., send/recv ops) with each other using `step_id` to
270   // distinguish tensors generated by different runs.
271   virtual int64 step_id() const = 0;
272 
273   // Options for this step.
274   virtual const ExecutorOpts& exec_opts() const = 0;
275 
276   // Sends the tensors in "send" into the graph before the run.
  // `SendValue` stores the i^{th} send tensor in `out_tensor`; the returned
  // Status reports the outcome.
277   virtual size_t num_sends() const = 0;
278   virtual const string& send_key(size_t i) const = 0;
279   virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;
280 
281   // Fetches the keys into `RunGraphResponse.recv` after the run.
282   virtual size_t num_recvs() const = 0;
283   virtual const string& recv_key(size_t i) const = 0;
284 
285   // True if the RunGraphRequest is a partial run request.
286   virtual bool is_partial() const = 0;
287 
288   // True if this is the last partial run request in a sequence of requests.
289   virtual bool is_last_partial_run() const = 0;
290 
291   // If true then some errors, e.g., execution errors that have long
292   // error messages, may return an OK RunGraphResponse with the actual
293   // error saved in the status_code/status_error_message fields of the
294   // response body. This is a workaround since the RPC subsystem may
295   // truncate long metadata messages.
296   virtual bool store_errors_in_response_body() const = 0;
297 
  // Identifier for this request.
  // NOTE(review): presumably mirrors `RunGraphRequest.request_id` in
  // worker.proto (unique per request, reused on retries, zero disabling
  // retry detection) — confirm against the proto definition.
298   virtual int64 request_id() const = 0;
299 
300   // Returns the wrapped data as a protocol buffer message.
301   virtual const RunGraphRequest& ToProto() const = 0;
302 };
303 
304 // Abstract interface for a mutable RunGraphRequest message.
305 //
306 // See `RunGraphRequestWrapper` above for a description of the fields.
// This interface adds pure-virtual mutators on top of the inherited
// read-only accessors.
307 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
308  public:
  // Setters for the scalar/string fields of the request.
309   virtual void set_session_handle(const string& handle) = 0;
310   virtual void set_create_worker_session_called(bool called) = 0;
311   virtual void set_graph_handle(const string& handle) = 0;
312   virtual void set_step_id(int64 step_id) = 0;
  // Returns a pointer through which the executor options can be modified.
313   virtual ExecutorOpts* mutable_exec_opts() = 0;
314 
315   // Stores the i^{th} feed value in `run_step_request` in this
316   // request with the given `send_key`.
317   virtual Status AddSendFromRunStepRequest(
318       const RunStepRequestWrapper& run_step_request, size_t i,
319       const string& send_key) = 0;
  // Counterpart of the method above whose source is a `RunCallableRequest`
  // (presumably its i^{th} feed — confirm against the implementation).
320   virtual Status AddSendFromRunCallableRequest(
321       const RunCallableRequest& run_callable_request, size_t i,
322       const string& send_key) = 0;
323 
  // Adds a key to fetch after the run, and sets the remaining flags/ids
  // that are read back through the base-interface accessors.
324   virtual void add_recv_key(const string& recv_key) = 0;
325   virtual void set_is_partial(bool is_partial) = 0;
326   virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
327   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
328   virtual void set_request_id(int64 request_id) = 0;
329 };
330 
331 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
332  public:
333   // RunGraphRequestWrapper methods.
334   const string& session_handle() const override;
335   const string& graph_handle() const override;
336   bool create_worker_session_called() const override;
337   int64 step_id() const override;
338   const ExecutorOpts& exec_opts() const override;
339   size_t num_sends() const override;
340   const string& send_key(size_t i) const override;
341   Status SendValue(size_t i, Tensor* out_tensor) const override;
342   size_t num_recvs() const override;
343   const string& recv_key(size_t i) const override;
344   bool is_partial() const override;
345   bool is_last_partial_run() const override;
346   const RunGraphRequest& ToProto() const override;
347   bool store_errors_in_response_body() const override;
348   int64 request_id() const override;
349 
350   // MutableRunGraphRequestWrapper methods.
351   void set_session_handle(const string& handle) override;
352   void set_create_worker_session_called(bool called) override;
353   void set_graph_handle(const string& handle) override;
354   void set_step_id(int64 step_id) override;
355   ExecutorOpts* mutable_exec_opts() override;
356   Status AddSendFromRunStepRequest(
357       const RunStepRequestWrapper& run_step_request, size_t i,
358       const string& send_key) override;
359   Status AddSendFromRunCallableRequest(
360       const RunCallableRequest& run_callable_request, size_t i,
361       const string& send_key) override;
362   void add_recv_key(const string& recv_key) override;
363   void set_is_partial(bool is_partial) override;
364   void set_is_last_partial_run(bool is_last_partial_run) override;
365   void set_store_errors_in_response_body(bool store_errors) override;
366   void set_request_id(int64 request_id) override;
367 
368  private:
369   string session_handle_;
370   bool create_worker_session_called_ = false;
371   string graph_handle_;
372   int64 step_id_;
373   ExecutorOpts exec_opts_;
374   gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_;
375   gtl::InlinedVector<string, 4> recvs_;
376   bool is_partial_ = false;
377   bool is_last_partial_run_ = false;
378   bool store_errors_in_response_body_ = false;
379   int64 request_id_ = 0;
380 
381   // Holds a cached and owned representation of the proto
382   // representation of this request, if needed, so that `ToProto()`
383   // can return a const RunGraphRequest&.
384   // NOTE(mrry): Although calls to `ToProto()` on this class are
385   // expected to be rare, retaining ownership of the returned message
386   // makes it easier to return a reference from the proto-backed
387   // representations.
388   mutable std::unique_ptr<RunGraphRequest> proto_version_;
389 };
390 
// Mutable RunGraph request wrapper backed by a `RunGraphRequest` protobuf
// message that this object owns (held by value).
391 class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
392  public:
393   // RunGraphRequestWrapper methods.
394   const string& session_handle() const override;
395   bool create_worker_session_called() const override;
396   const string& graph_handle() const override;
397   int64 step_id() const override;
398   const ExecutorOpts& exec_opts() const override;
399   size_t num_sends() const override;
400   const string& send_key(size_t i) const override;
401   Status SendValue(size_t i, Tensor* out_tensor) const override;
402   size_t num_recvs() const override;
403   const string& recv_key(size_t i) const override;
404   bool is_partial() const override;
405   bool is_last_partial_run() const override;
406   bool store_errors_in_response_body() const override;
407   int64 request_id() const override;
408   const RunGraphRequest& ToProto() const override;
409 
410   // MutableRunGraphRequestWrapper methods.
411   void set_session_handle(const string& handle) override;
412   void set_create_worker_session_called(bool called) override;
413   void set_graph_handle(const string& handle) override;
414   void set_step_id(int64 step_id) override;
415   ExecutorOpts* mutable_exec_opts() override;
416   Status AddSendFromRunStepRequest(
417       const RunStepRequestWrapper& run_step_request, size_t i,
418       const string& send_key) override;
419   Status AddSendFromRunCallableRequest(
420       const RunCallableRequest& run_callable_request, size_t i,
421       const string& send_key) override;
422   void add_recv_key(const string& recv_key) override;
423   void set_is_partial(bool is_partial) override;
424   void set_is_last_partial_run(bool is_last_partial_run) override;
425   void set_store_errors_in_response_body(bool store_errors) override;
426   void set_request_id(int64 request_id) override;
427 
428  private:
  // The owned protobuf message that backs every accessor and mutator.
429   RunGraphRequest request_;
430 };
431 
// Immutable RunGraph request wrapper around a `RunGraphRequest` protobuf
// message that this object does not own.
432 class ProtoRunGraphRequest : public RunGraphRequestWrapper {
433  public:
  // Wraps `request` without taking ownership (presumably stored in
  // `request_` below, so it must outlive this wrapper).
  // NOTE(review): this single-argument constructor is not `explicit`, so a
  // `const RunGraphRequest*` converts implicitly to this wrapper — confirm
  // that this is intended.
434   ProtoRunGraphRequest(const RunGraphRequest* request);
435 
436   // RunGraphRequestWrapper methods.
437   const string& session_handle() const override;
438   bool create_worker_session_called() const override;
439   const string& graph_handle() const override;
440   int64 step_id() const override;
441   const ExecutorOpts& exec_opts() const override;
442   size_t num_sends() const override;
443   const string& send_key(size_t i) const override;
444   Status SendValue(size_t i, Tensor* out_tensor) const override;
445   size_t num_recvs() const override;
446   const string& recv_key(size_t i) const override;
447   bool is_partial() const override;
448   bool is_last_partial_run() const override;
449   bool store_errors_in_response_body() const override;
450   int64 request_id() const override;
451   const RunGraphRequest& ToProto() const override;
452 
453  private:
  // Both the pointer and the pointee are const: the wrapper can neither be
  // re-pointed nor modify the message.
454   const RunGraphRequest* const request_;  // Not owned.
455 };
456 
457 ////////////////////////////////////////////////////////////////////////////////
458 //
459 // Wrapper classes for the `WorkerService.RunGraph` response message.
460 //
461 // The `RunGraphResponse` message can contain potentially large tensor
462 // data as part of its `recv` submessages. Here we provide specialized
463 // wrappers that avoid copying the tensor data wherever possible.
464 //
465 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
466 // protocol buffer definition.
467 //
468 ////////////////////////////////////////////////////////////////////////////////
469 
470 // Abstract interface for a mutable RunGraphResponse message.
471 //
472 // Note that there is no corresponding (immutable)
473 // RunGraphResponseWrapper class, because the RunGraphResponse object
474 // is always used as a mutable pointer.
475 class MutableRunGraphResponseWrapper {
476  public:
477   virtual ~MutableRunGraphResponseWrapper() {}
478 
479   // A list of tensors corresponding to those requested by
480   // `RunGraphRequest.recv_key`.
481   virtual size_t num_recvs() const = 0;
482   virtual const string& recv_key(size_t i) const = 0;
483   // NOTE: The following methods may perform a destructive read, for
484   // efficiency.
  // Because of that, both `RecvValue` overloads are non-const.
485   virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
486   virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
  // Adds the tensor `value` to the response under `key`.
487   virtual void AddRecv(const string& key, const Tensor& value) = 0;
488 
489   // Submessages that store performance statistics about the subgraph
490   // execution, if necessary.
491   virtual StepStats* mutable_step_stats() = 0;
492   virtual CostGraphDef* mutable_cost_graph() = 0;
493   virtual size_t num_partition_graphs() const = 0;
494   virtual GraphDef* mutable_partition_graph(size_t i) = 0;
495   virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;
496 
497   // Returned status if requested.
498   virtual errors::Code status_code() const = 0;
499   virtual const string& status_error_message() const = 0;
500   virtual void set_status(const Status& status) = 0;
501 
502  protected:
503   // Returns a mutable protobuf message that represents the contents of
504   // this wrapper, for passing to an RPC subsystem that will populate
505   // the message.
506   //
507   // NOTE: Only `WorkerInterface` subclasses may call this method. The
508   // `InMemoryRunGraphResponse` subclass does not implement this
509   // method, and attempts to call it will fail with a fatal
510   // error. However, as long as callers always call
511   // `WorkerInterface::RunGraphAsync()` with a wrapper object returned
512   // from `WorkerInterface::CreateRunGraphResponse()` called on the
513   // *same* WorkerInterface object, this error will never trigger.
514   virtual RunGraphResponse* get_proto() = 0;
  // Grants `WorkerInterface` access to the protected `get_proto()`.
515   friend class WorkerInterface;
516 };
517 
// RunGraph response wrapper that keeps the received tensors, statistics,
// partition graphs and status as plain members instead of in a
// `RunGraphResponse` proto. `get_proto()` is not implemented for this class.
518 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
519  public:
520   // MutableRunGraphResponseWrapper methods.
521   size_t num_recvs() const override;
522   const string& recv_key(size_t i) const override;
523   Status RecvValue(size_t i, TensorProto* out_tensor) override;
524   Status RecvValue(size_t i, Tensor* out_tensor) override;
525   void AddRecv(const string& key, const Tensor& value) override;
526   StepStats* mutable_step_stats() override;
527   CostGraphDef* mutable_cost_graph() override;
528   size_t num_partition_graphs() const override;
529   GraphDef* mutable_partition_graph(size_t i) override;
530   void AddPartitionGraph(const GraphDef& partition_graph) override;
531   errors::Code status_code() const override;
532   const string& status_error_message() const override;
533   void set_status(const Status& status) override;
534 
535  protected:
536   // NOTE: This method is not implemented. See
537   // MutableRunGraphResponseWrapper for an explanation.
538   RunGraphResponse* get_proto() override;
539 
540  private:
  // (recv key, tensor) pairs, with inline capacity for 4 entries.
541   gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
542   StepStats step_stats_;
543   CostGraphDef cost_graph_;
544   std::vector<GraphDef> partition_graphs_;
545   // The returned status, held as a single `Status` object (not as separate
546   // code and message fields).
547   Status status_;
548 };
549 
550 // Proto-based message wrapper for use on the client side of the RunGraph RPC.
//
// The wrapped `RunGraphResponse` is owned by this object (held by value).
551 class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
552  public:
553   // MutableRunGraphResponseWrapper methods.
554   size_t num_recvs() const override;
555   const string& recv_key(size_t i) const override;
556   Status RecvValue(size_t i, TensorProto* out_tensor) override;
557   Status RecvValue(size_t i, Tensor* out_tensor) override;
558   void AddRecv(const string& key, const Tensor& value) override;
559   StepStats* mutable_step_stats() override;
560   CostGraphDef* mutable_cost_graph() override;
561   size_t num_partition_graphs() const override;
562   GraphDef* mutable_partition_graph(size_t i) override;
563   void AddPartitionGraph(const GraphDef& partition_graph) override;
564   errors::Code status_code() const override;
565   const string& status_error_message() const override;
566   void set_status(const Status& status) override;
567 
568  protected:
  // Protected override of the base-class hook that exposes the underlying
  // protobuf message (presumably `&response_` — confirm in the .cc file).
569   RunGraphResponse* get_proto() override;
570 
571  private:
  // The owned protobuf message that backs this wrapper.
572   RunGraphResponse response_;
573 };
574 
575 // Proto-based message wrapper for use on the server side of the RunGraph RPC.
//
// Wraps a `RunGraphResponse*` supplied at construction time.
576 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
577  public:
  // NOTE(review): `response` is presumably not owned (per the class name)
  // and must outlive this wrapper — confirm. The constructor is also not
  // `explicit`, so a `RunGraphResponse*` converts implicitly.
578   NonOwnedProtoRunGraphResponse(RunGraphResponse* response);
579 
580   // MutableRunGraphResponseWrapper methods.
581   size_t num_recvs() const override;
582   const string& recv_key(size_t i) const override;
583   Status RecvValue(size_t i, TensorProto* out_tensor) override;
584   Status RecvValue(size_t i, Tensor* out_tensor) override;
585   void AddRecv(const string& key, const Tensor& value) override;
586   StepStats* mutable_step_stats() override;
587   CostGraphDef* mutable_cost_graph() override;
588   size_t num_partition_graphs() const override;
589   GraphDef* mutable_partition_graph(size_t i) override;
590   void AddPartitionGraph(const GraphDef& partition_graph) override;
591   errors::Code status_code() const override;
592   const string& status_error_message() const override;
593   void set_status(const Status& status) override;
594 
595  protected:
  // Protected override of the base-class hook that exposes the underlying
  // protobuf message (presumably `response_` — confirm in the .cc file).
596   RunGraphResponse* get_proto() override;
597 
598  private:
  // Const pointer to a mutable message: it cannot be re-pointed after
  // construction. Presumably not owned, as the class name indicates.
599   RunGraphResponse* const response_;
600 };
601 
602 ////////////////////////////////////////////////////////////////////////////////
603 //
604 // Wrapper classes for the `MasterService.RunStep` response message.
605 //
606 // The `RunStepResponse` message can contain potentially large tensor
607 // data as part of its `tensor` submessages. Here we provide specialized
608 // wrappers that avoid copying the tensor data wherever possible.
609 //
610 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
611 // protocol buffer definition.
612 //
613 ////////////////////////////////////////////////////////////////////////////////
614 
615 // Abstract interface for a mutable RunStepResponse message.
616 //
617 // Note that there is no corresponding (immutable)
618 // RunStepResponseWrapper class, because the RunStepResponse object is
619 // always used as a mutable pointer.
620 class MutableRunStepResponseWrapper {
621  public:
  // Only declared here (no inline body); the definition presumably lives
  // in the corresponding .cc file.
622   virtual ~MutableRunStepResponseWrapper();
623 
624   // The values of the tensors whose fetching was requested in the
625   // RunStep call.
626   //
627   // NOTE: The order of the returned tensors may or may not match
628   // the fetch order specified in RunStepRequest.
629   virtual size_t num_tensors() const = 0;
630   virtual const string& tensor_name(size_t i) const = 0;
  // Stores the i^{th} tensor value in `out_tensor`; the returned Status
  // reports the outcome.
631   virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;
632 
633   // Stores the i^{th} recv value in `run_graph_response` in this
634   // response with the given `name`.
  // `run_graph_response` is taken as a mutable pointer; note that its
  // `RecvValue` methods are documented as possibly destructive reads.
635   virtual Status AddTensorFromRunGraphResponse(
636       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
637       size_t i) = 0;
638 
639   // Returned metadata if requested in the options.
640   virtual const RunMetadata& metadata() const = 0;
641   virtual RunMetadata* mutable_metadata() = 0;
642 
643   // Returned status if requested.
644   virtual errors::Code status_code() const = 0;
645   virtual const string& status_error_message() const = 0;
646   virtual void set_status(const Status& status) = 0;
647 
648  protected:
649   // Returns a mutable protobuf message that represents the contents of
650   // this wrapper, for passing to an RPC subsystem that will populate
651   // the message.
652   //
653   // NOTE: Only `MasterInterface` subclasses may call this method. The
654   // `InMemoryRunStepResponse` subclass does not implement this
655   // method, and attempts to call it will fail with a fatal
656   // error. However, as long as callers always call
657   // `MasterInterface::RunStep()` with a wrapper object returned
658   // from `MasterInterface::CreateRunStepResponse()` called on the
659   // *same* MasterInterface object, this error will never trigger.
660   virtual RunStepResponse* get_proto() = 0;
  // Grants `MasterInterface` access to the protected `get_proto()`.
661   friend class MasterInterface;
662 };
663 
// RunStep response wrapper that keeps the fetched tensors, metadata and
// status as plain members instead of in a `RunStepResponse` proto.
// `get_proto()` is not implemented for this class.
664 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
665  public:
666   // MutableRunStepResponseWrapper methods.
667   size_t num_tensors() const override;
668   const string& tensor_name(size_t i) const override;
669   Status TensorValue(size_t i, Tensor* out_tensor) const override;
670   Status AddTensorFromRunGraphResponse(
671       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
672       size_t i) override;
673   const RunMetadata& metadata() const override;
674   RunMetadata* mutable_metadata() override;
675   errors::Code status_code() const override;
676   const string& status_error_message() const override;
677   void set_status(const Status& status) override;
678 
679  protected:
680   // NOTE: This method is not implemented. See
681   // MutableRunStepResponseWrapper for an explanation.
682   RunStepResponse* get_proto() override;
683 
684  private:
  // (tensor name, tensor) pairs, with inline capacity for 4 entries.
685   gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
686   RunMetadata metadata_;
687   // The returned status, held as a single `Status` object (not as separate
688   // code and message fields).
689   Status status_;
690 };
691 
692 // Proto-based message wrapper for use on the client side of the RunStep RPC.
//
// The wrapped `RunStepResponse` is owned by this object (held by value).
693 class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
694  public:
695   // MutableRunStepResponseWrapper methods.
696   size_t num_tensors() const override;
697   const string& tensor_name(size_t i) const override;
698   Status TensorValue(size_t i, Tensor* out_tensor) const override;
699   Status AddTensorFromRunGraphResponse(
700       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
701       size_t i) override;
702   const RunMetadata& metadata() const override;
703   RunMetadata* mutable_metadata() override;
704   errors::Code status_code() const override;
705   const string& status_error_message() const override;
706   void set_status(const Status& status) override;
707 
708  protected:
  // Protected override of the base-class hook that exposes the underlying
  // protobuf message (presumably `&response_` — confirm in the .cc file).
709   RunStepResponse* get_proto() override;
710 
711  private:
  // The owned protobuf message that backs this wrapper.
712   RunStepResponse response_;
713 };
714 
715 // Proto-based message wrapper for use on the server side of the RunStep RPC.
//
// Wraps a `RunStepResponse*` that this object does not own.
716 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
717  public:
  // Wraps `response` without taking ownership (presumably stored in
  // `response_` below, so it must outlive this wrapper).
  // NOTE(review): this single-argument constructor is not `explicit`, so a
  // `RunStepResponse*` converts implicitly to this wrapper — confirm that
  // this is intended.
718   NonOwnedProtoRunStepResponse(RunStepResponse* response);
719 
720   // MutableRunStepResponseWrapper methods.
721   size_t num_tensors() const override;
722   const string& tensor_name(size_t i) const override;
723   Status TensorValue(size_t i, Tensor* out_tensor) const override;
724   Status AddTensorFromRunGraphResponse(
725       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
726       size_t i) override;
727   const RunMetadata& metadata() const override;
728   RunMetadata* mutable_metadata() override;
729   errors::Code status_code() const override;
730   const string& status_error_message() const override;
731   void set_status(const Status& status) override;
732 
733  protected:
  // Protected override of the base-class hook that exposes the underlying
  // protobuf message (presumably `response_` — confirm in the .cc file).
734   RunStepResponse* get_proto() override;
735 
736  private:
  // NOTE(review): unlike `NonOwnedProtoRunGraphResponse::response_`, this
  // pointer is not declared `* const`.
737   RunStepResponse* response_;  // Not owned.
738 };
739 
// Converts `tensor_proto` into `*out_tensor`.
// NOTE(review): the bool result presumably signals whether parsing
// succeeded — confirm against the definition, which is not in this header.
740 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
741                               Tensor* out_tensor);
742 
743 }  // namespace tensorflow
744 
745 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
746