1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Inline functions for parsing the protocol buffers wire format.
17 //
18 // These functions have been optimized at the expense of safety.
19 // They are broken out into a separate file for readability but are
20 // not intended for use by clients other than the decode_proto op.
21 //
22 // The calling code in the decode_proto op does some fairly
23 // complicated things to ensure that this code is called
24 // safely. Changes to this code should be thoroughly fuzz tested.
25
26 #ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
27 #define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
28
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/types.h"
33
34 namespace tensorflow {
35 namespace internal {
36
37 using tensorflow::protobuf::internal::WireFormatLite;
38 using tensorflow::protobuf::io::CodedInputStream;
39 using tensorflow::protobuf::io::CodedOutputStream;
40 using tensorflow::protobuf::io::StringOutputStream;
41
42 // Converts an uint64 to an int64 without loss of information.
43 // Unsigned values greater than INT64_MAX are represented as
44 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned64(uint64 unsigned_value)45 inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
46 // For a detailed explanation of why this works to wrap unsigned ints, see
47 // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
48 // Both if tests should be optimized out.
49 if (unsigned_value <= INT64_MAX) {
50 return static_cast<int64>(unsigned_value);
51 }
52 // The C++ spec allows an architecture where this test is required.
53 if (unsigned_value >= INT64_MIN) {
54 return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
55 }
56 return 0; // This should never occur.
57 }
58
59 // Converts an uint32 to an int32 without loss of information.
60 // Unsigned values greater than INT_MAX are represented as
61 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned32(uint32 unsigned_value)62 inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
63 // For a detailed explanation of why this works to wrap unsigned ints, see
64 // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
65 // Both if tests should be optimized out.
66 if (unsigned_value <= INT_MAX) {
67 return static_cast<int32>(unsigned_value);
68 }
69 // The C++ spec allows an architecture where this test is required.
70 if (unsigned_value >= INT_MIN) {
71 return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
72 }
73 return 0; // This should never occur.
74 }
75
// Reads a single varint64 from a byte array.
// It is the caller's responsibility to ensure that there is enough
// space in the buffer.
// The ok value will be set to false if the buffer does not contain
// a valid varint.
// Declared here so that ReadVarint32FromArray below can call it; the
// definition is at the end of this file.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
                                          uint64* value);
83
84 // Reads a single varint32 from a byte array.
85 // It is the caller's responsibility to ensure that there is enough
86 // space in the buffer.
87 // The ok value will be set to false if the buffer does not contain
88 // a valid varint.
89 // This is slightly less efficient than the private version in
90 // coded_stream.cc but we duplicate less code by calling
91 // the 64 bit version instead of copying the code.
ReadVarint32FromArray(const uint8 * buffer,bool * ok,uint32 * value)92 inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
93 uint32* value) {
94 uint64 tmp = 0;
95 const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
96 *value = tmp & 0xffffffff;
97 return buf;
98 }
99
// Reads a single proto field value from a byte array into an array.
// The array is part of a Tensor that was allocated by the caller
// with type TensorType, while DeclaredType is the proto field type.
// Each specialization returns a pointer just past the bytes it consumed.
// No generic definition is provided in this file; only the explicit
// (TensorType, DeclaredType) specializations below are implemented.
template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
const uint8* ReadFromArray(const uint8* buf, TensorType* value);
105
106 template <>
107 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT32>(
108 const uint8* buf, int64* value) {
109 uint32 temp = 0;
110 bool unused_ok; // The Counting pass would have failed if this were corrupt.
111 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
112 *value = static_cast<int64>(temp);
113 return buf;
114 }
115
116 template <>
117 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
118 const uint8* buf, int32* value) {
119 uint32 temp = 0;
120 bool unused_ok; // The Counting pass would have failed if this were corrupt.
121 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
122 *value = static_cast<int32>(temp);
123 return buf;
124 }
125
126 template <>
127 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
128 const uint8* buf, int64* value) {
129 uint64 temp = 0;
130 bool unused_ok; // The Counting pass would have failed if this were corrupt.
131 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
132 *value = WrapUnsignedAsSigned64(temp);
133 return buf;
134 }
135
136 template <>
137 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>(
138 const uint8* buf, uint64* value) {
139 uint32 temp = 0;
140 bool unused_ok; // The Counting pass would have failed if this were corrupt.
141 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
142 *value = temp;
143 return buf;
144 }
145
146 template <>
147 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>(
148 const uint8* buf, uint32* value) {
149 bool unused_ok; // The Counting pass would have failed if this were corrupt.
150 return ReadVarint32FromArray(buf, &unused_ok, value);
151 }
152
153 template <>
154 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>(
155 const uint8* buf, uint64* value) {
156 bool unused_ok; // The Counting pass would have failed if this were corrupt.
157 return ReadVarint64FromArray(buf, &unused_ok, value);
158 }
159
160 template <>
161 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT32>(
162 const uint8* buf, int64* value) {
163 uint64 temp = 0;
164 bool unused_ok; // The Counting pass would have failed if this were corrupt.
165 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
166 *value = WireFormatLite::ZigZagDecode32(temp);
167 return buf;
168 }
169
170 template <>
171 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
172 const uint8* buf, int32* value) {
173 uint32 temp = 0;
174 bool unused_ok; // The Counting pass would have failed if this were corrupt.
175 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
176 *value = WireFormatLite::ZigZagDecode32(temp);
177 return buf;
178 }
179
180 template <>
181 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
182 const uint8* buf, int64* value) {
183 uint64 temp = 0;
184 bool unused_ok; // The Counting pass would have failed if this were corrupt.
185 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
186 *value = WireFormatLite::ZigZagDecode64(temp);
187 return buf;
188 }
189
190 template <>
191 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>(
192 const uint8* buf, uint64* value) {
193 uint32 temp;
194 buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
195 WireFormatLite::TYPE_FIXED32>(
196 buf, &temp);
197 *value = temp;
198 return buf;
199 }
200
201 template <>
202 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>(
203 const uint8* buf, uint32* value) {
204 uint32 temp;
205 buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
206 WireFormatLite::TYPE_FIXED32>(
207 buf, &temp);
208 *value = WrapUnsignedAsSigned32(temp);
209 return buf;
210 }
211
212 template <>
213 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>(
214 const uint8* buf, uint64* value) {
215 protobuf_uint64 temp;
216 buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
217 WireFormatLite::TYPE_FIXED64>(
218 buf, &temp);
219 *value = WrapUnsignedAsSigned64(temp);
220 return buf;
221 }
222
223 template <>
224 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED32>(
225 const uint8* buf, int64* value) {
226 int32 temp;
227 buf = WireFormatLite::ReadPrimitiveFromArray<int32,
228 WireFormatLite::TYPE_SFIXED32>(
229 buf, &temp);
230 *value = temp;
231 return buf;
232 }
233
234 template <>
235 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
236 const uint8* buf, int32* value) {
237 return WireFormatLite::ReadPrimitiveFromArray<int32,
238 WireFormatLite::TYPE_SFIXED32>(
239 buf, value);
240 }
241
242 template <>
243 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
244 const uint8* buf, int64* value) {
245 protobuf_int64 temp;
246 buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
247 WireFormatLite::TYPE_SFIXED64>(
248 buf, &temp);
249 *value = temp;
250 return buf;
251 }
252
253 template <>
254 inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
255 const uint8* buf, float* value) {
256 return WireFormatLite::ReadPrimitiveFromArray<float,
257 WireFormatLite::TYPE_FLOAT>(
258 buf, value);
259 }
260
261 template <>
262 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>(
263 const uint8* buf, double* value) {
264 float temp;
265 buf =
266 WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>(
267 buf, &temp);
268 *value = temp;
269 return buf;
270 }
271
272 template <>
273 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
274 const uint8* buf, double* value) {
275 return WireFormatLite::ReadPrimitiveFromArray<double,
276 WireFormatLite::TYPE_DOUBLE>(
277 buf, value);
278 }
279
280 template <>
281 inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
282 const uint8* buf, bool* value) {
283 uint64 temp = 0;
284 bool unused_ok; // The Counting pass would have failed if this were corrupt.
285 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
286 *value = temp != 0;
287 return buf;
288 }
289
290 template <>
291 inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
292 const uint8* buf, int* value) {
293 uint32 temp = 0;
294 bool unused_ok; // The Counting pass would have failed if this were corrupt.
295 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
296 *value = static_cast<int>(temp);
297 return buf;
298 }
299
300 // Reads packed values from an array.
301 // Stride is set to 1 for repeated fields, and 0 for non-repeated fields
302 // (where any value overwrites previous values).
303 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
ReadPackedPrimitives(const void * bufp,const size_t len,const int index,const int stride,void * datap)304 inline int ReadPackedPrimitives(const void* bufp, const size_t len,
305 const int index, const int stride,
306 void* datap) {
307 const uint8* buf = reinterpret_cast<const uint8*>(bufp);
308 const uint8* bound = buf + len;
309 TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
310 int count;
311
312 // This could overrun the bound by stride-1. This is defended
313 // against in the caller, where it ensures that the input buffer
314 // contains complete values.
315 for (count = 0; buf < bound; count += stride) {
316 buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
317 }
318 return count;
319 }
320
321 // Reads a value of a primitive type field from a serialized proto.
322 // The value is parsed from the serialized format, then static_cast
323 // to the desired type for TensorFlow and stored.
324 template <class ValueType, class TensorType,
325 enum WireFormatLite::FieldType DeclaredType>
ReadPrimitive(CodedInputStream * input,int index,void * data)326 inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
327 ValueType v;
328 if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
329 return errors::DataLoss("Failed reading primitive");
330 }
331
332 reinterpret_cast<TensorType*>(data)[index] = v;
333 return Status::OK();
334 }
335
336 // Reads a string, submessage, or other variable-length field from a
337 // serialized proto.
338 // May read all or part of a repeated field.
ReadBytes(CodedInputStream * input,int index,void * datap)339 inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
340 string* data = reinterpret_cast<string*>(datap) + index;
341 if (!WireFormatLite::ReadBytes(input, data)) {
342 return errors::DataLoss("Failed reading bytes");
343 }
344 return Status::OK();
345 }
346
347 // Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
348 // as a bytestring.
ReadGroupBytes(CodedInputStream * input,int field_number,int index,void * datap)349 inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
350 int index, void* datap) {
351 // WireFormatLite::SkipField has an option to emit the
352 // skipped bytes to an output stream. We could do better by implementing our
353 // own scanner but this is simpler for now.
354 // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
355 // on input->IsFlat() == true and using input->GetDirectBufferPointer()
356 // with input->CurrentPosition().
357 string* data = reinterpret_cast<string*>(datap) + index;
358 StringOutputStream string_stream(data);
359 CodedOutputStream out(&string_stream);
360 if (!WireFormatLite::SkipField(
361 input,
362 WireFormatLite::MakeTag(field_number,
363 WireFormatLite::WIRETYPE_START_GROUP),
364 &out)) {
365 return errors::DataLoss("Failed reading group");
366 }
367 return Status::OK();
368 }
369
370 // Reads a single field value from a CodedInputStream into a tensor.
ReadValue(CodedInputStream * input,WireFormatLite::FieldType field_type,int field_number,DataType dtype,int index,void * datap)371 inline Status ReadValue(CodedInputStream* input,
372 WireFormatLite::FieldType field_type, int field_number,
373 DataType dtype, int index, void* datap) {
374 // Dispatch to the appropriately typed field reader based on the schema type.
375 switch (field_type) {
376 case WireFormatLite::TYPE_DOUBLE:
377 return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
378 input, index, datap);
379 case WireFormatLite::TYPE_FLOAT:
380 switch (dtype) {
381 case DataType::DT_DOUBLE:
382 return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
383 input, index, datap);
384 case DataType::DT_FLOAT:
385 return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
386 input, index, datap);
387 default:
388 return errors::DataLoss("Failed reading TYPE_FLOAT for ",
389 DataTypeString(dtype));
390 }
391 case WireFormatLite::TYPE_INT64:
392 return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
393 input, index, datap);
394 case WireFormatLite::TYPE_UINT64:
395 return ReadPrimitive<protobuf_uint64, uint64,
396 WireFormatLite::TYPE_UINT64>(input, index, datap);
397 case WireFormatLite::TYPE_INT32:
398 switch (dtype) {
399 case DataType::DT_INT64:
400 return ReadPrimitive<int32, int64, WireFormatLite::TYPE_INT32>(
401 input, index, datap);
402 case DataType::DT_INT32:
403 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
404 input, index, datap);
405 default:
406 return errors::DataLoss("Failed reading TYPE_INT32 for ",
407 DataTypeString(dtype));
408 }
409 case WireFormatLite::TYPE_FIXED64:
410 return ReadPrimitive<protobuf_uint64, uint64,
411 WireFormatLite::TYPE_FIXED64>(input, index, datap);
412 case WireFormatLite::TYPE_FIXED32:
413 switch (dtype) {
414 case DataType::DT_UINT64:
415 return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>(
416 input, index, datap);
417 case DataType::DT_UINT32:
418 return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>(
419 input, index, datap);
420 default:
421 return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
422 DataTypeString(dtype));
423 }
424 case WireFormatLite::TYPE_BOOL:
425 return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
426 datap);
427 case WireFormatLite::TYPE_STRING:
428 return ReadBytes(input, index, datap);
429 case WireFormatLite::TYPE_GROUP:
430 return ReadGroupBytes(input, field_number, index, datap);
431 case WireFormatLite::TYPE_MESSAGE:
432 return ReadBytes(input, index, datap);
433 case WireFormatLite::TYPE_BYTES:
434 return ReadBytes(input, index, datap);
435 case WireFormatLite::TYPE_UINT32:
436 switch (dtype) {
437 case DataType::DT_UINT64:
438 return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>(
439 input, index, datap);
440 case DataType::DT_UINT32:
441 return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>(
442 input, index, datap);
443 default:
444 return errors::DataLoss("Failed reading TYPE_UINT32 for ",
445 DataTypeString(dtype));
446 }
447 case WireFormatLite::TYPE_ENUM:
448 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
449 input, index, datap);
450 case WireFormatLite::TYPE_SFIXED32:
451 switch (dtype) {
452 case DataType::DT_INT64:
453 return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SFIXED32>(
454 input, index, datap);
455 case DataType::DT_INT32:
456 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
457 input, index, datap);
458 default:
459 return errors::DataLoss("Failed reading TYPE_SFIXED32 for ",
460 DataTypeString(dtype));
461 }
462 case WireFormatLite::TYPE_SFIXED64:
463 return ReadPrimitive<protobuf_int64, int64,
464 WireFormatLite::TYPE_SFIXED64>(input, index, datap);
465 case WireFormatLite::TYPE_SINT32:
466 switch (dtype) {
467 case DataType::DT_INT64:
468 return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SINT32>(
469 input, index, datap);
470 case DataType::DT_INT32:
471 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
472 input, index, datap);
473 default:
474 return errors::DataLoss("Failed reading TYPE_SINT32 for ",
475 DataTypeString(dtype));
476 }
477 case WireFormatLite::TYPE_SINT64:
478 return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
479 input, index, datap);
480 // default: intentionally omitted in order to enable static checking.
481 }
482 // Unreachable.
483 return errors::DataLoss("Failed reading unknown wire type");
484 }
485
486 // Reads and stores a length-delimited list of values.
ReadPackedFromArray(const void * buf,size_t buf_size,const WireFormatLite::FieldType field_type,const int field_number,const DataType dtype,const int stride,int * index,void * data)487 inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
488 const WireFormatLite::FieldType field_type,
489 const int field_number, const DataType dtype,
490 const int stride, int* index, void* data) {
491 // Dispatch to the appropriately typed field reader based on the schema type.
492 switch (field_type) {
493 case WireFormatLite::TYPE_DOUBLE:
494 *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
495 buf, buf_size, *index, stride, data);
496 return Status::OK();
497 case WireFormatLite::TYPE_FLOAT:
498 switch (dtype) {
499 case DataType::DT_DOUBLE:
500 *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>(
501 buf, buf_size, *index, stride, data);
502 return Status::OK();
503 case DataType::DT_FLOAT:
504 *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
505 buf, buf_size, *index, stride, data);
506 return Status::OK();
507 default:
508 return errors::DataLoss("Failed reading TYPE_FLOAT for ",
509 DataTypeString(dtype));
510 }
511 case WireFormatLite::TYPE_INT64:
512 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
513 buf, buf_size, *index, stride, data);
514 return Status::OK();
515 case WireFormatLite::TYPE_UINT64:
516 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>(
517 buf, buf_size, *index, stride, data);
518 return Status::OK();
519 case WireFormatLite::TYPE_INT32:
520 switch (dtype) {
521 case DataType::DT_INT64:
522 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT32>(
523 buf, buf_size, *index, stride, data);
524 return Status::OK();
525 case DataType::DT_INT32:
526 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
527 buf, buf_size, *index, stride, data);
528 return Status::OK();
529 default:
530 return errors::DataLoss("Failed reading TYPE_INT32 for ",
531 DataTypeString(dtype));
532 }
533 case WireFormatLite::TYPE_FIXED64:
534 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>(
535 buf, buf_size, *index, stride, data);
536 return Status::OK();
537 case WireFormatLite::TYPE_FIXED32:
538 switch (dtype) {
539 case DataType::DT_UINT64:
540 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>(
541 buf, buf_size, *index, stride, data);
542 return Status::OK();
543 case DataType::DT_UINT32:
544 *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>(
545 buf, buf_size, *index, stride, data);
546 return Status::OK();
547 default:
548 return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
549 DataTypeString(dtype));
550 }
551 case WireFormatLite::TYPE_BOOL:
552 *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
553 buf, buf_size, *index, stride, data);
554 return Status::OK();
555 case WireFormatLite::TYPE_STRING:
556 case WireFormatLite::TYPE_GROUP:
557 case WireFormatLite::TYPE_MESSAGE:
558 case WireFormatLite::TYPE_BYTES:
559 return errors::DataLoss("Non-primitive type encountered as packed");
560 case WireFormatLite::TYPE_UINT32:
561 switch (dtype) {
562 case DataType::DT_UINT64:
563 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>(
564 buf, buf_size, *index, stride, data);
565 return Status::OK();
566 case DataType::DT_UINT32:
567 *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>(
568 buf, buf_size, *index, stride, data);
569 return Status::OK();
570 default:
571 return errors::DataLoss("Failed reading TYPE_UINT32 for ",
572 DataTypeString(dtype));
573 }
574 case WireFormatLite::TYPE_ENUM:
575 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
576 buf, buf_size, *index, stride, data);
577 return Status::OK();
578 case WireFormatLite::TYPE_SFIXED32:
579 switch (dtype) {
580 case DataType::DT_INT64:
581 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED32>(
582 buf, buf_size, *index, stride, data);
583 return Status::OK();
584 case DataType::DT_INT32:
585 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
586 buf, buf_size, *index, stride, data);
587 return Status::OK();
588 default:
589 return errors::DataLoss("Failed reading TYPE_INT32 for ",
590 DataTypeString(dtype));
591 }
592 case WireFormatLite::TYPE_SFIXED64:
593 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
594 buf, buf_size, *index, stride, data);
595 return Status::OK();
596
597 case WireFormatLite::TYPE_SINT32:
598 switch (dtype) {
599 case DataType::DT_INT64:
600 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT32>(
601 buf, buf_size, *index, stride, data);
602 return Status::OK();
603 case DataType::DT_INT32:
604 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
605 buf, buf_size, *index, stride, data);
606 return Status::OK();
607 default:
608 return errors::DataLoss("Failed reading TYPE_SINT32 for ",
609 DataTypeString(dtype));
610 }
611 case WireFormatLite::TYPE_SINT64:
612 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
613 buf, buf_size, *index, stride, data);
614 return Status::OK();
615 // default: intentionally omitted in order to enable static checking.
616 }
617 // Unreachable.
618 return errors::DataLoss("Failed reading unknown wire type");
619 }
620
// Reads a base-128 varint (at most 10 bytes, 7 payload bits per byte) from
// `buffer`, writes it to *value, and returns a pointer just past the last
// byte consumed.
// This was copied from coded_stream.cc where it is private.
// Important: This routine may read as much as kMaxVarintBytes from
// the buffer. It is the caller's responsibility to make sure that there is
// enough space in the buffer.
//
// On success *ok is set to true and *value is written. If the tenth byte
// still has its continuation bit (0x80) set, *ok is set to false, *value is
// left untouched, and the returned pointer is buffer + 10.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
                                          uint64* value) {
  const uint8* ptr = buffer;
  uint32 b;

  // Splitting into 32-bit pieces gives better performance on 32-bit
  // processors.
  // part0 collects payload bits 0-27 (bytes 1-4), part1 bits 28-55
  // (bytes 5-8), and part2 bits 56-63 (bytes 9-10); they are combined at
  // `done`.
  uint32 part0 = 0, part1 = 0, part2 = 0;

  // Each step adds the whole byte at its shifted position, stops if the
  // continuation bit is clear, and otherwise subtracts that bit back out so
  // that only the 7 payload bits remain.
  b = *(ptr++);
  part0 = b;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80;
  b = *(ptr++);
  part0 += b << 7;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 7;
  b = *(ptr++);
  part0 += b << 14;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 14;
  b = *(ptr++);
  part0 += b << 21;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 21;
  b = *(ptr++);
  part1 = b;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80;
  b = *(ptr++);
  part1 += b << 7;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 7;
  b = *(ptr++);
  part1 += b << 14;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 14;
  b = *(ptr++);
  part1 += b << 21;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 21;
  b = *(ptr++);
  part2 = b;
  if (!(b & 0x80)) goto done;
  part2 -= 0x80;
  b = *(ptr++);
  part2 += b << 7;
  if (!(b & 0x80)) goto done;
  // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.

  // We have overrun the maximum size of a varint (10 bytes). Assume
  // the data is corrupt.
  *ok = false;
  return ptr;

 done:
  *ok = true;
  // The pieces occupy disjoint bit ranges; bits of part2 above bit 7 are
  // shifted out of the 64-bit result.
  *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
           (static_cast<uint64>(part2) << 56);
  return ptr;
}
688
689 } // namespace internal
690 } // namespace tensorflow
691
692 #endif // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
693