• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
18 
19 #include "tensorflow/core/framework/allocator.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 #include "tensorflow/core/platform/types.h"
24 #include "tensorflow/core/protobuf/worker.pb.h"
25 
26 namespace tensorflow {
27 
// Forward declarations for types that the declarations below refer to only
// through pointers.
class TensorProto;
class DeviceBase;
class Allocator;
31 
32 // TensorResponse can be used as the destination of an RPC that returns
33 // a RecvTensorResponse.  It efficiently decodes the incoming data
34 // into Tensor contents as well as associated metadata.
35 class TensorResponse {
36  public:
TensorResponse()37   TensorResponse() {}
38 
39   // Reset to initial state.
40   void Clear();
41 
42   // Clear just tensor_ and meta_ members without setting allocation
43   // related members.
44   void ClearTensor();
45 
46   // Initialize memory allocation related members.
47   void InitAlloc(DeviceBase* d, const AllocatorAttributes& aa);
48 
49   // Source provides a way for a particular RPC implementation to provide
50   // received data to ParseFrom.
51   class Source {
52    public:
53     virtual ~Source();
54 
55     // Return the stream that contains the data to be parsed.
56     // Note that this method might be invoked more than once if
57     // ParseFrom needs to fall back to a more expensive parsing method.
58     // Every call must return a stream pointing at the beginning of
59     // the serialized RecvTensorResponse.
60     //
61     // Note that a subsequent call to contents() invalidates previous
62     // results of contents().
63     //
64     // Ownership of the returned stream is retained by the Source and
65     // should not be deleted by the caller.
66     virtual ::tensorflow::protobuf::io::ZeroCopyInputStream* contents() = 0;
67   };
68 
69   // Parse the RecvTensorResponse encoded in the data yielded by
70   // source->contents() into *this.
71   Status ParseFrom(Source* source);
72 
73   // Initialize tensor from *response.
74   // Leaves *response with unspecified contents.
75   Status InitFrom(RecvTensorResponse* response);
76 
77   // Initialize tensor metadata from response and allocate
78   // uninitialized backing storage for actual contents.
79   void InitPartial(const RecvTensorResponse& response,
80                    const AllocationAttributes& allocation_attr);
81 
82   // Return a reference to the parsed tensor.  The tensor will remain
83   // live only until *this is destroyed or modified.
tensor()84   const Tensor& tensor() const { return tensor_; }
85 
86   // Return a reference to the parsed tensor metadata (no contents).
87   // The result will remain live only until *this is destroyed or
88   // modified.
metadata()89   const RecvTensorResponse& metadata() const { return meta_; }
90 
91   // Return pointer to the device hosting the tensor.
device()92   DeviceBase* device() const { return device_; }
93 
94  private:
95   bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
96                              TensorProto* tensor_meta);
97   bool ParseFast(Source* source);
98   bool ParseSlow(Source* source);
99 
100   bool on_host_ = false;
101   DeviceBase* device_ = nullptr;
102   AllocatorAttributes alloc_attrs_;
103   Allocator* allocator_ = nullptr;
104   bool already_used_ = false;
105   Tensor tensor_;
106   RecvTensorResponse meta_;
107 };
108 
109 }  // namespace tensorflow
110 
111 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
112