1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
17
18 #include "google/protobuf/any.pb.h"
19
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23
24 namespace tensorflow {
25
~Source()26 TensorResponse::Source::~Source() {}
27
Clear()28 void TensorResponse::Clear() {
29 on_host_ = false;
30 device_ = nullptr;
31 alloc_attrs_ = AllocatorAttributes();
32 allocator_ = nullptr;
33 already_used_ = false;
34 ClearTensor();
35 }
36
ClearTensor()37 void TensorResponse::ClearTensor() {
38 meta_.Clear();
39 tensor_ = Tensor();
40 }
41
InitAlloc(DeviceBase * d,const AllocatorAttributes & aa)42 void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) {
43 Clear();
44 device_ = d;
45 alloc_attrs_ = aa;
46 const DeviceAttributes& da = d->attributes();
47 if (alloc_attrs_.on_host() || da.device_type() == "CPU") {
48 on_host_ = true;
49 }
50 allocator_ = device_->GetAllocator(alloc_attrs_);
51 }
52
InitFrom(RecvTensorResponse * response)53 Status TensorResponse::InitFrom(RecvTensorResponse* response) {
54 Status s;
55 meta_.Swap(response);
56 if (on_host_) {
57 if (!tensor_.FromProto(allocator_, meta_.tensor())) {
58 s = errors::InvalidArgument("Cannot parse tensor from response");
59 }
60 } else {
61 s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
62 }
63 {
64 TensorProto empty;
65 meta_.mutable_tensor()->Swap(&empty);
66 }
67 meta_.clear_tensor();
68 return s;
69 }
70
InitPartial(const RecvTensorResponse & response,const AllocationAttributes & allocation_attr)71 void TensorResponse::InitPartial(const RecvTensorResponse& response,
72 const AllocationAttributes& allocation_attr) {
73 // Everything except content is present in *response. Content will
74 // arrive later; allocate a Tensor with appropriate storage for that
75 // content.
76 meta_ = response;
77 TensorShape shape(meta_.tensor().tensor_shape());
78 Tensor t(allocator_, meta_.tensor().dtype(), shape, allocation_attr);
79 tensor_ = std::move(t);
80 }
81
ParseFrom(Source * source)82 Status TensorResponse::ParseFrom(Source* source) {
83 if (!on_host_) {
84 protobuf::io::CodedInputStream input(source->contents());
85 input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited
86
87 // Pre-parse into local storage, then delegate to device.
88 if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) {
89 return errors::InvalidArgument("Cannot parse tensor from response");
90 }
91 Status s =
92 device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
93 // Reduce memory usage for big tensors.
94 {
95 TensorProto empty;
96 meta_.mutable_tensor()->Swap(&empty);
97 }
98 meta_.clear_tensor();
99 return s;
100 }
101 if (already_used_) {
102 ClearTensor();
103 }
104 already_used_ = true;
105 if (ParseFast(source)) return Status::OK();
106 meta_.Clear();
107 if (ParseSlow(source)) return Status::OK();
108 return errors::InvalidArgument("Cannot parse tensor from response");
109 }
110
111 // Define some helper routines for decoding protocol buffer wire format data
112 namespace {
113 // We only need some of the wiretype values for this code
114 enum WireType {
115 WIRETYPE_VARINT = 0,
116 WIRETYPE_LENGTH_DELIMITED = 2,
117 };
GetTagFieldNumber(uint32 tag)118 inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
GetTagWireType(uint32 tag)119 inline WireType GetTagWireType(uint32 tag) {
120 return static_cast<WireType>(tag & 0x7);
121 }
122
ReadVarintSizeAsInt(protobuf::io::CodedInputStream * input,int * result)123 bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
124 protobuf_uint64 v;
125 if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
126 *result = static_cast<int>(v);
127 return true;
128 } else {
129 return false;
130 }
131 }
132
ReadNestedMessage(protobuf::io::CodedInputStream * input,protobuf::Message * value)133 bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
134 protobuf::Message* value) {
135 int length;
136 if (!ReadVarintSizeAsInt(input, &length)) return false;
137 std::pair<protobuf::io::CodedInputStream::Limit, int> p =
138 input->IncrementRecursionDepthAndPushLimit(length);
139 if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
140 // Make sure that parsing stopped when the limit was hit, not at an endgroup
141 // tag.
142 return input->DecrementRecursionDepthAndPopLimit(p.first);
143 }
144
145 } // namespace
146
// Fast-path decoder for the TensorProto submessage inside a
// RecvTensorResponse.  Decodes dtype / tensor_shape / version_number into
// *tensor_meta, and reads the tensor_content bytes directly into a Tensor
// allocated from allocator_ (avoiding an intermediate proto copy of the
// data).  Returns false whenever the wire data contains anything this fast
// path cannot handle; the caller then falls back to ParseSlow.
bool TensorResponse::ParseTensorSubmessage(
    protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
  bool seen_tensor_content = false;
  while (true) {
    auto p = input->ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      // End of the submessage: tag == 0 means we stopped cleanly at the
      // length limit rather than on a malformed tag.
      bool ok = (tag == 0);
      if (ok && !seen_tensor_content) {
        // No tensor content: could be because it's a zero-length tensor
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        tensor_ = std::move(t);
      }
      return ok;
    }
    switch (tag) {
      case TensorProto::kDtypeFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        // Metadata fields must precede tensor_content on this fast path.
        if (seen_tensor_content) return false;
        tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
        // Only memcpy-able dtypes are supported here; anything else (e.g.
        // string tensors) requires the general parser.
        if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
        break;
      }
      case TensorProto::kTensorShapeFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
          return false;
        if (seen_tensor_content) return false;
        break;
      }
      case TensorProto::kVersionNumberFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_version_number(static_cast<int32>(v));
        break;
      }
      case TensorProto::kTensorContentFieldNumber: {
        // If we haven't seen the dtype and tensor_shape data first, we can't
        // deal with this in the fast path.
        if (seen_tensor_content) return false;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !tensor_meta->has_tensor_shape()) {
          return false;
        }
        int num_bytes;
        if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
        seen_tensor_content = true;
        // Allocate the destination tensor from dtype+shape, then verify the
        // incoming byte count matches its flat buffer size exactly.
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        StringPiece buf = t.tensor_data();
        if (static_cast<size_t>(num_bytes) != buf.size()) return false;
        // TODO(jeff,sanjay): Figure out a way to avoid this copy if
        // the underlying ZeroCopyInputStream data is properly aligned
        // and compatible with what allocator_ wants.
        if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
          return false;
        tensor_ = std::move(t);
        break;
      }
      default: {
        // Some other tag our fast path code is not prepared to handle.
        // return false.
        return false;
      }
    }
  }
}
218
// Fast-path parser for a whole RecvTensorResponse read from `source`.
// Handles only the known top-level fields; any unknown field, wrong wire
// type, or a tensor submessage the fast path cannot decode makes it return
// false so the caller can retry with ParseSlow.
bool TensorResponse::ParseFast(Source* source) {
  protobuf::io::CodedInputStream input(source->contents());
  input.SetTotalBytesLimit(INT_MAX, INT_MAX);  // Unlimited
  while (true) {
    auto p = input.ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      // tag == 0 signals a clean end of stream; anything else is malformed.
      return (tag == 0);
    }
    switch (tag) {
      case RecvTensorResponse::kTensorFieldNumber: {
        if (wt != WIRETYPE_LENGTH_DELIMITED) return false;

        int length;
        if (!ReadVarintSizeAsInt(&input, &length)) return false;
        // Push a byte limit for the submessage and bump the recursion depth,
        // mirroring what protobuf's generated parsing code does.
        // NOTE: this `p` intentionally shadows the tag pair above.
        std::pair<protobuf::io::CodedInputStream::Limit, int> p =
            input.IncrementRecursionDepthAndPushLimit(length);
        if (p.second < 0 ||
            !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
          return false;
        }
        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
          return false;
        }
        break;
      }
      case RecvTensorResponse::kIsDeadFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_is_dead(v != 0);
        break;
      }
      case RecvTensorResponse::kSendStartMicrosFieldNumber: {
        protobuf_uint64 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
        meta_.set_send_start_micros(static_cast<int64>(v));
        break;
      }
      case RecvTensorResponse::kTransportOptionsFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(&input, meta_.mutable_transport_options()))
          return false;
        break;
      }
      case RecvTensorResponse::kRequireAckFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_require_ack(v != 0);
        break;
      }
      default: {
        // Unknown tag: the fast path cannot handle it, so fail and let the
        // caller fall back to the slow (general) parser.
        return false;
      }
    }
  }

  // Unreachable: the loop above only exits via return statements.
  return false;
}
279
ParseSlow(Source * source)280 bool TensorResponse::ParseSlow(Source* source) {
281 if (!meta_.ParseFromZeroCopyStream(source->contents())) {
282 return false;
283 }
284
285 Tensor parsed(meta_.tensor().dtype());
286 if (!parsed.FromProto(allocator_, meta_.tensor())) {
287 return false;
288 }
289 tensor_ = std::move(parsed);
290
291 // Reduce memory usage for big tensors.
292 {
293 TensorProto empty;
294 meta_.mutable_tensor()->Swap(&empty);
295 }
296 meta_.clear_tensor();
297
298 return true;
299 }
300
301 } // namespace tensorflow
302