• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Inline functions for parsing the protocol buffers wire format.
17 //
18 // These functions have been optimized at the expense of safety.
19 // They are broken out into a separate file for readability but are
20 // not intended for use by clients other than the decode_proto op.
21 //
22 // The calling code in the decode_proto op does some fairly
23 // complicated things to ensure that this code is called
24 // safely. Changes to this code should be thoroughly fuzz tested.
25 
26 #ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
27 #define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
28 
#include <climits>
#include <cstdint>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 namespace internal {
36 
37 using tensorflow::protobuf::internal::WireFormatLite;
38 using tensorflow::protobuf::io::CodedInputStream;
39 using tensorflow::protobuf::io::CodedOutputStream;
40 using tensorflow::protobuf::io::StringOutputStream;
41 
42 // Converts an uint64 to an int64 without loss of information.
43 // Unsigned values greater than INT64_MAX are represented as
44 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned64(uint64 unsigned_value)45 inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
46   // For a detailed explanation of why this works to wrap unsigned ints, see
47   // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
48   // Both if tests should be optimized out.
49   if (unsigned_value <= INT64_MAX) {
50     return static_cast<int64>(unsigned_value);
51   }
52   // The C++ spec allows an architecture where this test is required.
53   if (unsigned_value >= INT64_MIN) {
54     return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
55   }
56   return 0;  // This should never occur.
57 }
58 
59 // Converts an uint32 to an int32 without loss of information.
60 // Unsigned values greater than INT_MAX are represented as
61 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned32(uint32 unsigned_value)62 inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
63   // For a detailed explanation of why this works to wrap unsigned ints, see
64   // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
65   // Both if tests should be optimized out.
66   if (unsigned_value <= INT_MAX) {
67     return static_cast<int32>(unsigned_value);
68   }
69   // The C++ spec allows an architecture where this test is required.
70   if (unsigned_value >= INT_MIN) {
71     return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
72   }
73   return 0;  // This should never occur.
74 }
75 
76 // Reads a single varint32 from a byte array.
77 // It is the caller's responsibility to ensure that there is enough
78 // space in the buffer.
79 // The ok value will be set to false if the buffer does not contain
80 // a valid varint.
81 inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
82                                           uint64* value);
83 
84 // Reads a single varint32 from a byte array.
85 // It is the caller's responsibility to ensure that there is enough
86 // space in the buffer.
87 // The ok value will be set to false if the buffer does not contain
88 // a valid varint.
89 // This is slightly less efficient than the private version in
90 // coded_stream.cc but we duplicate less code by calling
91 // the 64 bit version instead of copying the code.
ReadVarint32FromArray(const uint8 * buffer,bool * ok,uint32 * value)92 inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
93                                           uint32* value) {
94   uint64 tmp = 0;
95   const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
96   *value = tmp & 0xffffffff;
97   return buf;
98 }
99 
100 // Reads a single proto field value from a byte array into an array.
101 // The array is part of a Tensor that was allocated by the caller
102 // with type TensorType, while DeclaredType is the proto field type.
103 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
104 const uint8* ReadFromArray(const uint8* buf, TensorType* value);
105 
106 template <>
107 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT32>(
108     const uint8* buf, int64* value) {
109   uint32 temp = 0;
110   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
111   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
112   *value = static_cast<int64>(temp);
113   return buf;
114 }
115 
116 template <>
117 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
118     const uint8* buf, int32* value) {
119   uint32 temp = 0;
120   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
121   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
122   *value = static_cast<int32>(temp);
123   return buf;
124 }
125 
126 template <>
127 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
128     const uint8* buf, int64* value) {
129   uint64 temp = 0;
130   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
131   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
132   *value = WrapUnsignedAsSigned64(temp);
133   return buf;
134 }
135 
136 template <>
137 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>(
138     const uint8* buf, uint64* value) {
139   uint32 temp = 0;
140   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
141   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
142   *value = temp;
143   return buf;
144 }
145 
146 template <>
147 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>(
148     const uint8* buf, uint32* value) {
149   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
150   return ReadVarint32FromArray(buf, &unused_ok, value);
151 }
152 
153 template <>
154 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>(
155     const uint8* buf, uint64* value) {
156   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
157   return ReadVarint64FromArray(buf, &unused_ok, value);
158 }
159 
160 template <>
161 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT32>(
162     const uint8* buf, int64* value) {
163   uint64 temp = 0;
164   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
165   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
166   *value = WireFormatLite::ZigZagDecode32(temp);
167   return buf;
168 }
169 
170 template <>
171 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
172     const uint8* buf, int32* value) {
173   uint32 temp = 0;
174   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
175   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
176   *value = WireFormatLite::ZigZagDecode32(temp);
177   return buf;
178 }
179 
180 template <>
181 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
182     const uint8* buf, int64* value) {
183   uint64 temp = 0;
184   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
185   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
186   *value = WireFormatLite::ZigZagDecode64(temp);
187   return buf;
188 }
189 
190 template <>
191 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>(
192     const uint8* buf, uint64* value) {
193   uint32 temp;
194   buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
195                                                WireFormatLite::TYPE_FIXED32>(
196       buf, &temp);
197   *value = temp;
198   return buf;
199 }
200 
201 template <>
202 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>(
203     const uint8* buf, uint32* value) {
204   uint32 temp;
205   buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
206                                                WireFormatLite::TYPE_FIXED32>(
207       buf, &temp);
208   *value = WrapUnsignedAsSigned32(temp);
209   return buf;
210 }
211 
212 template <>
213 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>(
214     const uint8* buf, uint64* value) {
215   protobuf_uint64 temp;
216   buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
217                                                WireFormatLite::TYPE_FIXED64>(
218       buf, &temp);
219   *value = WrapUnsignedAsSigned64(temp);
220   return buf;
221 }
222 
223 template <>
224 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED32>(
225     const uint8* buf, int64* value) {
226   int32 temp;
227   buf = WireFormatLite::ReadPrimitiveFromArray<int32,
228                                                WireFormatLite::TYPE_SFIXED32>(
229       buf, &temp);
230   *value = temp;
231   return buf;
232 }
233 
234 template <>
235 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
236     const uint8* buf, int32* value) {
237   return WireFormatLite::ReadPrimitiveFromArray<int32,
238                                                 WireFormatLite::TYPE_SFIXED32>(
239       buf, value);
240 }
241 
242 template <>
243 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
244     const uint8* buf, int64* value) {
245   protobuf_int64 temp;
246   buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
247                                                WireFormatLite::TYPE_SFIXED64>(
248       buf, &temp);
249   *value = temp;
250   return buf;
251 }
252 
253 template <>
254 inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
255     const uint8* buf, float* value) {
256   return WireFormatLite::ReadPrimitiveFromArray<float,
257                                                 WireFormatLite::TYPE_FLOAT>(
258       buf, value);
259 }
260 
261 template <>
262 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>(
263     const uint8* buf, double* value) {
264   float temp;
265   buf =
266       WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>(
267           buf, &temp);
268   *value = temp;
269   return buf;
270 }
271 
272 template <>
273 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
274     const uint8* buf, double* value) {
275   return WireFormatLite::ReadPrimitiveFromArray<double,
276                                                 WireFormatLite::TYPE_DOUBLE>(
277       buf, value);
278 }
279 
280 template <>
281 inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
282     const uint8* buf, bool* value) {
283   uint64 temp = 0;
284   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
285   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
286   *value = temp != 0;
287   return buf;
288 }
289 
290 template <>
291 inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
292     const uint8* buf, int* value) {
293   uint32 temp = 0;
294   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
295   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
296   *value = static_cast<int>(temp);
297   return buf;
298 }
299 
300 // Reads packed values from an array.
301 // Stride is set to 1 for repeated fields, and 0 for non-repeated fields
302 // (where any value overwrites previous values).
303 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
ReadPackedPrimitives(const void * bufp,const size_t len,const int index,const int stride,void * datap)304 inline int ReadPackedPrimitives(const void* bufp, const size_t len,
305                                 const int index, const int stride,
306                                 void* datap) {
307   const uint8* buf = reinterpret_cast<const uint8*>(bufp);
308   const uint8* bound = buf + len;
309   TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
310   int count;
311 
312   // This could overrun the bound by stride-1. This is defended
313   // against in the caller, where it ensures that the input buffer
314   // contains complete values.
315   for (count = 0; buf < bound; count += stride) {
316     buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
317   }
318   return count;
319 }
320 
321 // Reads a value of a primitive type field from a serialized proto.
322 // The value is parsed from the serialized format, then static_cast
323 // to the desired type for TensorFlow and stored.
324 template <class ValueType, class TensorType,
325           enum WireFormatLite::FieldType DeclaredType>
ReadPrimitive(CodedInputStream * input,int index,void * data)326 inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
327   ValueType v;
328   if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
329     return errors::DataLoss("Failed reading primitive");
330   }
331 
332   reinterpret_cast<TensorType*>(data)[index] = v;
333   return Status::OK();
334 }
335 
336 // Reads a string, submessage, or other variable-length field from a
337 // serialized proto.
338 // May read all or part of a repeated field.
ReadBytes(CodedInputStream * input,int index,void * datap)339 inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
340   tstring* data = reinterpret_cast<tstring*>(datap) + index;
341 
342   uint32 length;
343   if (!input->ReadVarint32(&length)) {
344     return errors::DataLoss("Failed reading bytes");
345   }
346 
347   data->resize_uninitialized(length);
348 
349   if (!input->ReadRaw(data->data(), length)) {
350     return errors::DataLoss("Failed reading bytes");
351   }
352   return Status::OK();
353 }
354 
355 // Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
356 // as a bytestring.
ReadGroupBytes(CodedInputStream * input,int field_number,int index,void * datap)357 inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
358                              int index, void* datap) {
359   // WireFormatLite::SkipField has an option to emit the
360   // skipped bytes to an output stream. We could do better by implementing our
361   // own scanner but this is simpler for now.
362   // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
363   // on input->IsFlat() == true and using input->GetDirectBufferPointer()
364   // with input->CurrentPosition().
365   tstring* data = reinterpret_cast<tstring*>(datap) + index;
366   // TODO(dero): To mitigate the string to tstring copy, we can implement our
367   // own scanner as described above.  We would first need to obtain the length
368   // in an initial pass and resize/reserve the tstring. But, given that
369   // TYPE_GROUP is deprecated and currently no tests in
370   // tensorflow/python/kernel_tests/proto:decode_proto_op_test target a
371   // TYPE_GROUP tag, we use std::string as a read buffer.
372   string buf;
373   StringOutputStream string_stream(&buf);
374   {
375     CodedOutputStream out(&string_stream);
376     if (!WireFormatLite::SkipField(
377             input,
378             WireFormatLite::MakeTag(field_number,
379                                     WireFormatLite::WIRETYPE_START_GROUP),
380             &out)) {
381       return errors::DataLoss("Failed reading group");
382     }
383   }
384   *data = buf;
385   return Status::OK();
386 }
387 
388 // Reads a single field value from a CodedInputStream into a tensor.
ReadValue(CodedInputStream * input,WireFormatLite::FieldType field_type,int field_number,DataType dtype,int index,void * datap)389 inline Status ReadValue(CodedInputStream* input,
390                         WireFormatLite::FieldType field_type, int field_number,
391                         DataType dtype, int index, void* datap) {
392   // Dispatch to the appropriately typed field reader based on the schema type.
393   switch (field_type) {
394     case WireFormatLite::TYPE_DOUBLE:
395       return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
396           input, index, datap);
397     case WireFormatLite::TYPE_FLOAT:
398       switch (dtype) {
399         case DataType::DT_DOUBLE:
400           return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
401               input, index, datap);
402         case DataType::DT_FLOAT:
403           return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
404               input, index, datap);
405         default:
406           return errors::DataLoss("Failed reading TYPE_FLOAT for ",
407                                   DataTypeString(dtype));
408       }
409     case WireFormatLite::TYPE_INT64:
410       return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
411           input, index, datap);
412     case WireFormatLite::TYPE_UINT64:
413       return ReadPrimitive<protobuf_uint64, uint64,
414                            WireFormatLite::TYPE_UINT64>(input, index, datap);
415     case WireFormatLite::TYPE_INT32:
416       switch (dtype) {
417         case DataType::DT_INT64:
418           return ReadPrimitive<int32, int64, WireFormatLite::TYPE_INT32>(
419               input, index, datap);
420         case DataType::DT_INT32:
421           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
422               input, index, datap);
423         default:
424           return errors::DataLoss("Failed reading TYPE_INT32 for ",
425                                   DataTypeString(dtype));
426       }
427     case WireFormatLite::TYPE_FIXED64:
428       return ReadPrimitive<protobuf_uint64, uint64,
429                            WireFormatLite::TYPE_FIXED64>(input, index, datap);
430     case WireFormatLite::TYPE_FIXED32:
431       switch (dtype) {
432         case DataType::DT_UINT64:
433           return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>(
434               input, index, datap);
435         case DataType::DT_UINT32:
436           return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>(
437               input, index, datap);
438         default:
439           return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
440                                   DataTypeString(dtype));
441       }
442     case WireFormatLite::TYPE_BOOL:
443       return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
444                                                                   datap);
445     case WireFormatLite::TYPE_STRING:
446       return ReadBytes(input, index, datap);
447     case WireFormatLite::TYPE_GROUP:
448       return ReadGroupBytes(input, field_number, index, datap);
449     case WireFormatLite::TYPE_MESSAGE:
450       return ReadBytes(input, index, datap);
451     case WireFormatLite::TYPE_BYTES:
452       return ReadBytes(input, index, datap);
453     case WireFormatLite::TYPE_UINT32:
454       switch (dtype) {
455         case DataType::DT_UINT64:
456           return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>(
457               input, index, datap);
458         case DataType::DT_UINT32:
459           return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>(
460               input, index, datap);
461         default:
462           return errors::DataLoss("Failed reading TYPE_UINT32 for ",
463                                   DataTypeString(dtype));
464       }
465     case WireFormatLite::TYPE_ENUM:
466       return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
467           input, index, datap);
468     case WireFormatLite::TYPE_SFIXED32:
469       switch (dtype) {
470         case DataType::DT_INT64:
471           return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SFIXED32>(
472               input, index, datap);
473         case DataType::DT_INT32:
474           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
475               input, index, datap);
476         default:
477           return errors::DataLoss("Failed reading TYPE_SFIXED32 for ",
478                                   DataTypeString(dtype));
479       }
480     case WireFormatLite::TYPE_SFIXED64:
481       return ReadPrimitive<protobuf_int64, int64,
482                            WireFormatLite::TYPE_SFIXED64>(input, index, datap);
483     case WireFormatLite::TYPE_SINT32:
484       switch (dtype) {
485         case DataType::DT_INT64:
486           return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SINT32>(
487               input, index, datap);
488         case DataType::DT_INT32:
489           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
490               input, index, datap);
491         default:
492           return errors::DataLoss("Failed reading TYPE_SINT32 for ",
493                                   DataTypeString(dtype));
494       }
495     case WireFormatLite::TYPE_SINT64:
496       return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
497           input, index, datap);
498       // default: intentionally omitted in order to enable static checking.
499   }
500   // Unreachable.
501   return errors::DataLoss("Failed reading unknown wire type");
502 }
503 
504 // Reads and stores a length-delimited list of values.
ReadPackedFromArray(const void * buf,size_t buf_size,const WireFormatLite::FieldType field_type,const int field_number,const DataType dtype,const int stride,int * index,void * data)505 inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
506                                   const WireFormatLite::FieldType field_type,
507                                   const int field_number, const DataType dtype,
508                                   const int stride, int* index, void* data) {
509   // Dispatch to the appropriately typed field reader based on the schema type.
510   switch (field_type) {
511     case WireFormatLite::TYPE_DOUBLE:
512       *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
513           buf, buf_size, *index, stride, data);
514       return Status::OK();
515     case WireFormatLite::TYPE_FLOAT:
516       switch (dtype) {
517         case DataType::DT_DOUBLE:
518           *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>(
519               buf, buf_size, *index, stride, data);
520           return Status::OK();
521         case DataType::DT_FLOAT:
522           *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
523               buf, buf_size, *index, stride, data);
524           return Status::OK();
525         default:
526           return errors::DataLoss("Failed reading TYPE_FLOAT for ",
527                                   DataTypeString(dtype));
528       }
529     case WireFormatLite::TYPE_INT64:
530       *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
531           buf, buf_size, *index, stride, data);
532       return Status::OK();
533     case WireFormatLite::TYPE_UINT64:
534       *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>(
535           buf, buf_size, *index, stride, data);
536       return Status::OK();
537     case WireFormatLite::TYPE_INT32:
538       switch (dtype) {
539         case DataType::DT_INT64:
540           *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT32>(
541               buf, buf_size, *index, stride, data);
542           return Status::OK();
543         case DataType::DT_INT32:
544           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
545               buf, buf_size, *index, stride, data);
546           return Status::OK();
547         default:
548           return errors::DataLoss("Failed reading TYPE_INT32 for ",
549                                   DataTypeString(dtype));
550       }
551     case WireFormatLite::TYPE_FIXED64:
552       *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>(
553           buf, buf_size, *index, stride, data);
554       return Status::OK();
555     case WireFormatLite::TYPE_FIXED32:
556       switch (dtype) {
557         case DataType::DT_UINT64:
558           *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>(
559               buf, buf_size, *index, stride, data);
560           return Status::OK();
561         case DataType::DT_UINT32:
562           *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>(
563               buf, buf_size, *index, stride, data);
564           return Status::OK();
565         default:
566           return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
567                                   DataTypeString(dtype));
568       }
569     case WireFormatLite::TYPE_BOOL:
570       *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
571           buf, buf_size, *index, stride, data);
572       return Status::OK();
573     case WireFormatLite::TYPE_STRING:
574     case WireFormatLite::TYPE_GROUP:
575     case WireFormatLite::TYPE_MESSAGE:
576     case WireFormatLite::TYPE_BYTES:
577       return errors::DataLoss("Non-primitive type encountered as packed");
578     case WireFormatLite::TYPE_UINT32:
579       switch (dtype) {
580         case DataType::DT_UINT64:
581           *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>(
582               buf, buf_size, *index, stride, data);
583           return Status::OK();
584         case DataType::DT_UINT32:
585           *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>(
586               buf, buf_size, *index, stride, data);
587           return Status::OK();
588         default:
589           return errors::DataLoss("Failed reading TYPE_UINT32 for ",
590                                   DataTypeString(dtype));
591       }
592     case WireFormatLite::TYPE_ENUM:
593       *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
594           buf, buf_size, *index, stride, data);
595       return Status::OK();
596     case WireFormatLite::TYPE_SFIXED32:
597       switch (dtype) {
598         case DataType::DT_INT64:
599           *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED32>(
600               buf, buf_size, *index, stride, data);
601           return Status::OK();
602         case DataType::DT_INT32:
603           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
604               buf, buf_size, *index, stride, data);
605           return Status::OK();
606         default:
607           return errors::DataLoss("Failed reading TYPE_INT32 for ",
608                                   DataTypeString(dtype));
609       }
610     case WireFormatLite::TYPE_SFIXED64:
611       *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
612           buf, buf_size, *index, stride, data);
613       return Status::OK();
614 
615     case WireFormatLite::TYPE_SINT32:
616       switch (dtype) {
617         case DataType::DT_INT64:
618           *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT32>(
619               buf, buf_size, *index, stride, data);
620           return Status::OK();
621         case DataType::DT_INT32:
622           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
623               buf, buf_size, *index, stride, data);
624           return Status::OK();
625         default:
626           return errors::DataLoss("Failed reading TYPE_SINT32 for ",
627                                   DataTypeString(dtype));
628       }
629     case WireFormatLite::TYPE_SINT64:
630       *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
631           buf, buf_size, *index, stride, data);
632       return Status::OK();
633       // default: intentionally omitted in order to enable static checking.
634   }
635   // Unreachable.
636   return errors::DataLoss("Failed reading unknown wire type");
637 }
638 
// Reads a base-128 varint of at most 10 bytes from `buffer`. The decoded
// value is written to *value, and the function returns a pointer just past
// the last byte consumed.
// This was copied from coded_stream.cc, where the equivalent routine is
// private.
//
// On success *ok is set to true. If all of the first 10 bytes have their
// continuation bit (0x80) set, the data is treated as corrupt: *ok is set to
// false, *value is left unmodified, and buffer + 10 is returned.
//
// Important: This routine may read as much as kMaxVarintBytes (10) from
// the buffer without any bounds check. It is the caller's responsibility to
// make sure that there is enough space in the buffer.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
                                          uint64* value) {
  const uint8* ptr = buffer;
  uint32 b;

  // Splitting into 32-bit pieces gives better performance on 32-bit
  // processors. part0 accumulates bits 0-27 (bytes 1-4), part1 bits 28-55
  // (bytes 5-8), and part2 bits 56-63 (bytes 9-10); they are combined at
  // `done`.
  uint32 part0 = 0, part1 = 0, part2 = 0;

  // Each step adds the whole byte, stops if its continuation bit is clear,
  // and otherwise subtracts the continuation bit back out (0x80 at that
  // byte's shift) before reading the next byte.
  b = *(ptr++);
  part0 = b;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80;
  b = *(ptr++);
  part0 += b << 7;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 7;
  b = *(ptr++);
  part0 += b << 14;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 14;
  b = *(ptr++);
  part0 += b << 21;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 21;
  b = *(ptr++);
  part1 = b;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80;
  b = *(ptr++);
  part1 += b << 7;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 7;
  b = *(ptr++);
  part1 += b << 14;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 14;
  b = *(ptr++);
  part1 += b << 21;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 21;
  b = *(ptr++);
  part2 = b;
  if (!(b & 0x80)) goto done;
  part2 -= 0x80;
  b = *(ptr++);
  part2 += b << 7;
  if (!(b & 0x80)) goto done;
  // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.

  // We have overrun the maximum size of a varint (10 bytes).  Assume
  // the data is corrupt.
  *ok = false;
  return ptr;

done:
  *ok = true;
  // Reassemble the three pieces; bits of part2 shifted past bit 63 are
  // discarded by the 64-bit shift.
  *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
           (static_cast<uint64>(part2) << 56);
  return ptr;
}
706 
707 }  // namespace internal
708 }  // namespace tensorflow
709 
710 #endif  // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
711