/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/tensor_coding.h"

#include "google/protobuf/any.pb.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace tensorflow {
TensorResponse::Source::~Source() {}

void TensorResponse::Clear() {
  on_host_ = false;
  device_ = nullptr;
  alloc_attrs_ = AllocatorAttributes();
  allocator_ = nullptr;
  already_used_ = false;
  ClearTensor();
}

void TensorResponse::ClearTensor() {
  meta_.Clear();
  tensor_ = Tensor();
}

void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) {
  Clear();
  device_ = d;
  alloc_attrs_ = aa;
  const DeviceAttributes& da = d->attributes();
  if (alloc_attrs_.on_host() || da.device_type() == "CPU") {
    on_host_ = true;
  }
  allocator_ = device_->GetAllocator(alloc_attrs_);
}

Status TensorResponse::InitFrom(RecvTensorResponse* response) {
  Status s;
  meta_.Swap(response);
  if (on_host_) {
    if (!tensor_.FromProto(allocator_, meta_.tensor())) {
      s = errors::InvalidArgument("Cannot parse tensor from response");
    }
  } else {
    s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
  }
  // Reduce memory usage for big tensors: swapping the parsed tensor proto
  // with an empty one releases its heap storage immediately.
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();
  return s;
}

void TensorResponse::InitPartial(const RecvTensorResponse& response,
                                 const AllocationAttributes& allocation_attr) {
  // Everything except content is present in *response.  Content will
  // arrive later; allocate a Tensor with appropriate storage for that
  // content.
  meta_ = response;
  TensorShape shape(meta_.tensor().tensor_shape());
  Tensor t(allocator_, meta_.tensor().dtype(), shape, allocation_attr);
  tensor_ = std::move(t);
}

Status TensorResponse::ParseFrom(Source* source) {
  if (!on_host_) {
    protobuf::io::CodedInputStream input(source->contents());
    input.SetTotalBytesLimit(INT_MAX, INT_MAX);  // Unlimited

    // Pre-parse into local storage, then delegate to device.
    if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) {
      return errors::InvalidArgument("Cannot parse tensor from response");
    }
    Status s =
        device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
    // Reduce memory usage for big tensors.
    {
      TensorProto empty;
      meta_.mutable_tensor()->Swap(&empty);
    }
    meta_.clear_tensor();
    return s;
  }
  if (already_used_) {
    ClearTensor();
  }
  already_used_ = true;
  if (ParseFast(source)) return Status::OK();
  meta_.Clear();
  if (ParseSlow(source)) return Status::OK();
  return errors::InvalidArgument("Cannot parse tensor from response");
}
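
// Illustrative call sequence (a sketch, not code from this file; the
// surrounding rendezvous plumbing and the tensor()/metadata() accessors
// declared in tensor_coding.h are assumed):
//
//   TensorResponse response;
//   response.InitAlloc(dst_device, recv_args.alloc_attrs);
//   Status s = response.ParseFrom(&source);  // source wraps the RPC payload
//   if (s.ok()) {
//     const Tensor& val = response.tensor();
//     bool is_dead = response.metadata().is_dead();
//   }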

// Define some helper routines for decoding protocol buffer wire format data
namespace {
// We only need some of the wiretype values for this code
enum WireType {
  WIRETYPE_VARINT = 0,
  WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32 tag) {
  return static_cast<WireType>(tag & 0x7);
}
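
// Worked example (informative): a wire-format tag encodes
// (field_number << 3) | wire_type, so for the tag byte 0x12,
// GetTagFieldNumber(0x12) == 2 and
// GetTagWireType(0x12) == WIRETYPE_LENGTH_DELIMITED.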

bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
  protobuf_uint64 v;
  if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
    *result = static_cast<int>(v);
    return true;
  } else {
    return false;
  }
}
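
// Worked example (informative): varints carry 7 payload bits per byte,
// low-order group first, with the high bit marking continuation; the
// two-byte sequence 0xAC 0x02 therefore decodes to 0x2C | (0x02 << 7) == 300.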

bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
                       protobuf::Message* value) {
  int length;
  if (!ReadVarintSizeAsInt(input, &length)) return false;
  std::pair<protobuf::io::CodedInputStream::Limit, int> p =
      input->IncrementRecursionDepthAndPushLimit(length);
  if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
  // Make sure that parsing stopped when the limit was hit, not at an endgroup
  // tag.
  return input->DecrementRecursionDepthAndPopLimit(p.first);
}

}  // namespace

bool TensorResponse::ParseTensorSubmessage(
    protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
  bool seen_tensor_content = false;
  while (true) {
    auto p = input->ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      bool ok = (tag == 0);
      if (ok && !seen_tensor_content) {
        // No tensor content: could be because it's a zero-length tensor
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        tensor_ = std::move(t);
      }
      return ok;
    }
    switch (tag) {
      case TensorProto::kDtypeFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
        if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
        break;
      }
      case TensorProto::kTensorShapeFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
          return false;
        if (seen_tensor_content) return false;
        break;
      }
      case TensorProto::kVersionNumberFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_version_number(static_cast<int32>(v));
        break;
      }
      case TensorProto::kTensorContentFieldNumber: {
        // If we haven't seen the dtype and tensor_shape data first, we can't
        // deal with this in the fast path.
        if (seen_tensor_content) return false;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !tensor_meta->has_tensor_shape()) {
          return false;
        }
        int num_bytes;
        if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
        seen_tensor_content = true;
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        StringPiece buf = t.tensor_data();
        if (static_cast<size_t>(num_bytes) != buf.size()) return false;
        // TODO(jeff,sanjay): Figure out a way to avoid this copy if
        // the underlying ZeroCopyInputStream data is properly aligned
        // and compatible with what allocator_ wants.
        if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
          return false;
        tensor_ = std::move(t);
        break;
      }
      default: {
        // Some other tag that our fast path code is not prepared to handle;
        // bail out so the caller can fall back to the slow path.
        return false;
      }
    }
  }
}

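// ParseFast hand-decodes the RecvTensorResponse wire format so the
// tensor_content bytes can be read directly into a Tensor allocated from
// allocator_, avoiding the intermediate TensorProto copy that the generic
// parse in ParseSlow incurs. Any field or wire type it is not prepared for
// makes it return false, after which ParseFrom retries with ParseSlow.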
bool TensorResponse::ParseFast(Source* source) {
  protobuf::io::CodedInputStream input(source->contents());
  input.SetTotalBytesLimit(INT_MAX, INT_MAX);  // Unlimited
  while (true) {
    auto p = input.ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      return (tag == 0);
    }
    switch (tag) {
      case RecvTensorResponse::kTensorFieldNumber: {
        if (wt != WIRETYPE_LENGTH_DELIMITED) return false;

        int length;
        if (!ReadVarintSizeAsInt(&input, &length)) return false;
        std::pair<protobuf::io::CodedInputStream::Limit, int> p =
            input.IncrementRecursionDepthAndPushLimit(length);
        if (p.second < 0 ||
            !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
          return false;
        }
        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
          return false;
        }
        break;
      }
      case RecvTensorResponse::kIsDeadFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_is_dead(v != 0);
        break;
      }
      case RecvTensorResponse::kSendStartMicrosFieldNumber: {
        protobuf_uint64 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
        meta_.set_send_start_micros(static_cast<int64>(v));
        break;
      }
      case RecvTensorResponse::kTransportOptionsFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(&input, meta_.mutable_transport_options()))
          return false;
        break;
      }
      case RecvTensorResponse::kRequireAckFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_require_ack(v != 0);
        break;
      }
      default: {
        // Unknown tag that we can't handle on the fast path; bail out so the
        // caller can fall back to the slow path.
        return false;
      }
    }
  }

  return false;
}

bool TensorResponse::ParseSlow(Source* source) {
  if (!meta_.ParseFromZeroCopyStream(source->contents())) {
    return false;
  }

  Tensor parsed(meta_.tensor().dtype());
  if (!parsed.FromProto(allocator_, meta_.tensor())) {
    return false;
  }
  tensor_ = std::move(parsed);

  // Reduce memory usage for big tensors.
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();

  return true;
}

}  // namespace tensorflow