1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Inline functions for parsing the protocol buffers wire format.
17 //
18 // These functions have been optimized at the expense of safety.
19 // They are broken out into a separate file for readability but are
20 // not intended for use by clients other than the decode_proto op.
21 //
22 // The calling code in the decode_proto op does some fairly
23 // complicated things to ensure that this code is called
24 // safely. Changes to this code should be thoroughly fuzz tested.
25
26 #ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
27 #define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
28
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/types.h"
33
34 namespace tensorflow {
35 namespace internal {
36
37 using tensorflow::protobuf::internal::WireFormatLite;
38 using tensorflow::protobuf::io::CodedInputStream;
39 using tensorflow::protobuf::io::CodedOutputStream;
40 using tensorflow::protobuf::io::StringOutputStream;
41
42 // Converts an uint64 to an int64 without loss of information.
43 // Unsigned values greater than INT64_MAX are represented as
44 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned64(uint64 unsigned_value)45 inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
46 // For a detailed explanation of why this works to wrap unsigned ints, see
47 // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
48 // Both if tests should be optimized out.
49 if (unsigned_value <= INT64_MAX) {
50 return static_cast<int64>(unsigned_value);
51 }
52 // The C++ spec allows an architecture where this test is required.
53 if (unsigned_value >= INT64_MIN) {
54 return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
55 }
56 return 0; // This should never occur.
57 }
58
59 // Converts an uint32 to an int32 without loss of information.
60 // Unsigned values greater than INT_MAX are represented as
61 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned32(uint32 unsigned_value)62 inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
63 // For a detailed explanation of why this works to wrap unsigned ints, see
64 // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
65 // Both if tests should be optimized out.
66 if (unsigned_value <= INT_MAX) {
67 return static_cast<int32>(unsigned_value);
68 }
69 // The C++ spec allows an architecture where this test is required.
70 if (unsigned_value >= INT_MIN) {
71 return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
72 }
73 return 0; // This should never occur.
74 }
75
// Reads a single varint64 from a byte array.
// It is the caller's responsibility to ensure that there is enough
// space in the buffer.
// The ok value will be set to false if the buffer does not contain
// a valid varint.
81 inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
82 uint64* value);
83
84 // Reads a single varint32 from a byte array.
85 // It is the caller's responsibility to ensure that there is enough
86 // space in the buffer.
87 // The ok value will be set to false if the buffer does not contain
88 // a valid varint.
89 // This is slightly less efficient than the private version in
90 // coded_stream.cc but we duplicate less code by calling
91 // the 64 bit version instead of copying the code.
ReadVarint32FromArray(const uint8 * buffer,bool * ok,uint32 * value)92 inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
93 uint32* value) {
94 uint64 tmp = 0;
95 const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
96 *value = tmp & 0xffffffff;
97 return buf;
98 }
99
100 // Reads a single proto field value from a byte array into an array.
101 // The array is part of a Tensor that was allocated by the caller
102 // with type TensorType, while DeclaredType is the proto field type.
103 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
104 const uint8* ReadFromArray(const uint8* buf, TensorType* value);
105
106 template <>
107 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT32>(
108 const uint8* buf, int64* value) {
109 uint32 temp = 0;
110 bool unused_ok; // The Counting pass would have failed if this were corrupt.
111 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
112 *value = static_cast<int64>(temp);
113 return buf;
114 }
115
116 template <>
117 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
118 const uint8* buf, int32* value) {
119 uint32 temp = 0;
120 bool unused_ok; // The Counting pass would have failed if this were corrupt.
121 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
122 *value = static_cast<int32>(temp);
123 return buf;
124 }
125
126 template <>
127 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
128 const uint8* buf, int64* value) {
129 uint64 temp = 0;
130 bool unused_ok; // The Counting pass would have failed if this were corrupt.
131 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
132 *value = WrapUnsignedAsSigned64(temp);
133 return buf;
134 }
135
136 template <>
137 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>(
138 const uint8* buf, uint64* value) {
139 uint32 temp = 0;
140 bool unused_ok; // The Counting pass would have failed if this were corrupt.
141 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
142 *value = temp;
143 return buf;
144 }
145
146 template <>
147 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>(
148 const uint8* buf, uint32* value) {
149 bool unused_ok; // The Counting pass would have failed if this were corrupt.
150 return ReadVarint32FromArray(buf, &unused_ok, value);
151 }
152
153 template <>
154 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>(
155 const uint8* buf, uint64* value) {
156 bool unused_ok; // The Counting pass would have failed if this were corrupt.
157 return ReadVarint64FromArray(buf, &unused_ok, value);
158 }
159
160 template <>
161 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT32>(
162 const uint8* buf, int64* value) {
163 uint64 temp = 0;
164 bool unused_ok; // The Counting pass would have failed if this were corrupt.
165 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
166 *value = WireFormatLite::ZigZagDecode32(temp);
167 return buf;
168 }
169
170 template <>
171 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
172 const uint8* buf, int32* value) {
173 uint32 temp = 0;
174 bool unused_ok; // The Counting pass would have failed if this were corrupt.
175 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
176 *value = WireFormatLite::ZigZagDecode32(temp);
177 return buf;
178 }
179
180 template <>
181 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
182 const uint8* buf, int64* value) {
183 uint64 temp = 0;
184 bool unused_ok; // The Counting pass would have failed if this were corrupt.
185 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
186 *value = WireFormatLite::ZigZagDecode64(temp);
187 return buf;
188 }
189
190 template <>
191 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>(
192 const uint8* buf, uint64* value) {
193 uint32 temp;
194 buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
195 WireFormatLite::TYPE_FIXED32>(
196 buf, &temp);
197 *value = temp;
198 return buf;
199 }
200
201 template <>
202 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>(
203 const uint8* buf, uint32* value) {
204 uint32 temp;
205 buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
206 WireFormatLite::TYPE_FIXED32>(
207 buf, &temp);
208 *value = WrapUnsignedAsSigned32(temp);
209 return buf;
210 }
211
212 template <>
213 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>(
214 const uint8* buf, uint64* value) {
215 protobuf_uint64 temp;
216 buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
217 WireFormatLite::TYPE_FIXED64>(
218 buf, &temp);
219 *value = WrapUnsignedAsSigned64(temp);
220 return buf;
221 }
222
223 template <>
224 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED32>(
225 const uint8* buf, int64* value) {
226 int32 temp;
227 buf = WireFormatLite::ReadPrimitiveFromArray<int32,
228 WireFormatLite::TYPE_SFIXED32>(
229 buf, &temp);
230 *value = temp;
231 return buf;
232 }
233
234 template <>
235 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
236 const uint8* buf, int32* value) {
237 return WireFormatLite::ReadPrimitiveFromArray<int32,
238 WireFormatLite::TYPE_SFIXED32>(
239 buf, value);
240 }
241
242 template <>
243 inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
244 const uint8* buf, int64* value) {
245 protobuf_int64 temp;
246 buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
247 WireFormatLite::TYPE_SFIXED64>(
248 buf, &temp);
249 *value = temp;
250 return buf;
251 }
252
253 template <>
254 inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
255 const uint8* buf, float* value) {
256 return WireFormatLite::ReadPrimitiveFromArray<float,
257 WireFormatLite::TYPE_FLOAT>(
258 buf, value);
259 }
260
261 template <>
262 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>(
263 const uint8* buf, double* value) {
264 float temp;
265 buf =
266 WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>(
267 buf, &temp);
268 *value = temp;
269 return buf;
270 }
271
272 template <>
273 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
274 const uint8* buf, double* value) {
275 return WireFormatLite::ReadPrimitiveFromArray<double,
276 WireFormatLite::TYPE_DOUBLE>(
277 buf, value);
278 }
279
280 template <>
281 inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
282 const uint8* buf, bool* value) {
283 uint64 temp = 0;
284 bool unused_ok; // The Counting pass would have failed if this were corrupt.
285 buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
286 *value = temp != 0;
287 return buf;
288 }
289
290 template <>
291 inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
292 const uint8* buf, int* value) {
293 uint32 temp = 0;
294 bool unused_ok; // The Counting pass would have failed if this were corrupt.
295 buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
296 *value = static_cast<int>(temp);
297 return buf;
298 }
299
300 // Reads packed values from an array.
301 // Stride is set to 1 for repeated fields, and 0 for non-repeated fields
302 // (where any value overwrites previous values).
303 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
ReadPackedPrimitives(const void * bufp,const size_t len,const int index,const int stride,void * datap)304 inline int ReadPackedPrimitives(const void* bufp, const size_t len,
305 const int index, const int stride,
306 void* datap) {
307 const uint8* buf = reinterpret_cast<const uint8*>(bufp);
308 const uint8* bound = buf + len;
309 TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
310 int count;
311
312 // This could overrun the bound by stride-1. This is defended
313 // against in the caller, where it ensures that the input buffer
314 // contains complete values.
315 for (count = 0; buf < bound; count += stride) {
316 buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
317 }
318 return count;
319 }
320
321 // Reads a value of a primitive type field from a serialized proto.
322 // The value is parsed from the serialized format, then static_cast
323 // to the desired type for TensorFlow and stored.
324 template <class ValueType, class TensorType,
325 enum WireFormatLite::FieldType DeclaredType>
ReadPrimitive(CodedInputStream * input,int index,void * data)326 inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
327 ValueType v;
328 if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
329 return errors::DataLoss("Failed reading primitive");
330 }
331
332 reinterpret_cast<TensorType*>(data)[index] = v;
333 return Status::OK();
334 }
335
336 // Reads a string, submessage, or other variable-length field from a
337 // serialized proto.
338 // May read all or part of a repeated field.
ReadBytes(CodedInputStream * input,int index,void * datap)339 inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
340 tstring* data = reinterpret_cast<tstring*>(datap) + index;
341
342 uint32 length;
343 if (!input->ReadVarint32(&length)) {
344 return errors::DataLoss("Failed reading bytes");
345 }
346
347 data->resize_uninitialized(length);
348
349 if (!input->ReadRaw(data->data(), length)) {
350 return errors::DataLoss("Failed reading bytes");
351 }
352 return Status::OK();
353 }
354
355 // Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
356 // as a bytestring.
ReadGroupBytes(CodedInputStream * input,int field_number,int index,void * datap)357 inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
358 int index, void* datap) {
359 // WireFormatLite::SkipField has an option to emit the
360 // skipped bytes to an output stream. We could do better by implementing our
361 // own scanner but this is simpler for now.
362 // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
363 // on input->IsFlat() == true and using input->GetDirectBufferPointer()
364 // with input->CurrentPosition().
365 tstring* data = reinterpret_cast<tstring*>(datap) + index;
366 // TODO(dero): To mitigate the string to tstring copy, we can implement our
367 // own scanner as described above. We would first need to obtain the length
368 // in an initial pass and resize/reserve the tstring. But, given that
369 // TYPE_GROUP is deprecated and currently no tests in
370 // tensorflow/python/kernel_tests/proto:decode_proto_op_test target a
371 // TYPE_GROUP tag, we use std::string as a read buffer.
372 string buf;
373 StringOutputStream string_stream(&buf);
374 {
375 CodedOutputStream out(&string_stream);
376 if (!WireFormatLite::SkipField(
377 input,
378 WireFormatLite::MakeTag(field_number,
379 WireFormatLite::WIRETYPE_START_GROUP),
380 &out)) {
381 return errors::DataLoss("Failed reading group");
382 }
383 }
384 *data = buf;
385 return Status::OK();
386 }
387
388 // Reads a single field value from a CodedInputStream into a tensor.
ReadValue(CodedInputStream * input,WireFormatLite::FieldType field_type,int field_number,DataType dtype,int index,void * datap)389 inline Status ReadValue(CodedInputStream* input,
390 WireFormatLite::FieldType field_type, int field_number,
391 DataType dtype, int index, void* datap) {
392 // Dispatch to the appropriately typed field reader based on the schema type.
393 switch (field_type) {
394 case WireFormatLite::TYPE_DOUBLE:
395 return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
396 input, index, datap);
397 case WireFormatLite::TYPE_FLOAT:
398 switch (dtype) {
399 case DataType::DT_DOUBLE:
400 return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
401 input, index, datap);
402 case DataType::DT_FLOAT:
403 return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
404 input, index, datap);
405 default:
406 return errors::DataLoss("Failed reading TYPE_FLOAT for ",
407 DataTypeString(dtype));
408 }
409 case WireFormatLite::TYPE_INT64:
410 return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
411 input, index, datap);
412 case WireFormatLite::TYPE_UINT64:
413 return ReadPrimitive<protobuf_uint64, uint64,
414 WireFormatLite::TYPE_UINT64>(input, index, datap);
415 case WireFormatLite::TYPE_INT32:
416 switch (dtype) {
417 case DataType::DT_INT64:
418 return ReadPrimitive<int32, int64, WireFormatLite::TYPE_INT32>(
419 input, index, datap);
420 case DataType::DT_INT32:
421 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
422 input, index, datap);
423 default:
424 return errors::DataLoss("Failed reading TYPE_INT32 for ",
425 DataTypeString(dtype));
426 }
427 case WireFormatLite::TYPE_FIXED64:
428 return ReadPrimitive<protobuf_uint64, uint64,
429 WireFormatLite::TYPE_FIXED64>(input, index, datap);
430 case WireFormatLite::TYPE_FIXED32:
431 switch (dtype) {
432 case DataType::DT_UINT64:
433 return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>(
434 input, index, datap);
435 case DataType::DT_UINT32:
436 return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>(
437 input, index, datap);
438 default:
439 return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
440 DataTypeString(dtype));
441 }
442 case WireFormatLite::TYPE_BOOL:
443 return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
444 datap);
445 case WireFormatLite::TYPE_STRING:
446 return ReadBytes(input, index, datap);
447 case WireFormatLite::TYPE_GROUP:
448 return ReadGroupBytes(input, field_number, index, datap);
449 case WireFormatLite::TYPE_MESSAGE:
450 return ReadBytes(input, index, datap);
451 case WireFormatLite::TYPE_BYTES:
452 return ReadBytes(input, index, datap);
453 case WireFormatLite::TYPE_UINT32:
454 switch (dtype) {
455 case DataType::DT_UINT64:
456 return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>(
457 input, index, datap);
458 case DataType::DT_UINT32:
459 return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>(
460 input, index, datap);
461 default:
462 return errors::DataLoss("Failed reading TYPE_UINT32 for ",
463 DataTypeString(dtype));
464 }
465 case WireFormatLite::TYPE_ENUM:
466 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
467 input, index, datap);
468 case WireFormatLite::TYPE_SFIXED32:
469 switch (dtype) {
470 case DataType::DT_INT64:
471 return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SFIXED32>(
472 input, index, datap);
473 case DataType::DT_INT32:
474 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
475 input, index, datap);
476 default:
477 return errors::DataLoss("Failed reading TYPE_SFIXED32 for ",
478 DataTypeString(dtype));
479 }
480 case WireFormatLite::TYPE_SFIXED64:
481 return ReadPrimitive<protobuf_int64, int64,
482 WireFormatLite::TYPE_SFIXED64>(input, index, datap);
483 case WireFormatLite::TYPE_SINT32:
484 switch (dtype) {
485 case DataType::DT_INT64:
486 return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SINT32>(
487 input, index, datap);
488 case DataType::DT_INT32:
489 return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
490 input, index, datap);
491 default:
492 return errors::DataLoss("Failed reading TYPE_SINT32 for ",
493 DataTypeString(dtype));
494 }
495 case WireFormatLite::TYPE_SINT64:
496 return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
497 input, index, datap);
498 // default: intentionally omitted in order to enable static checking.
499 }
500 // Unreachable.
501 return errors::DataLoss("Failed reading unknown wire type");
502 }
503
504 // Reads and stores a length-delimited list of values.
ReadPackedFromArray(const void * buf,size_t buf_size,const WireFormatLite::FieldType field_type,const int field_number,const DataType dtype,const int stride,int * index,void * data)505 inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
506 const WireFormatLite::FieldType field_type,
507 const int field_number, const DataType dtype,
508 const int stride, int* index, void* data) {
509 // Dispatch to the appropriately typed field reader based on the schema type.
510 switch (field_type) {
511 case WireFormatLite::TYPE_DOUBLE:
512 *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
513 buf, buf_size, *index, stride, data);
514 return Status::OK();
515 case WireFormatLite::TYPE_FLOAT:
516 switch (dtype) {
517 case DataType::DT_DOUBLE:
518 *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>(
519 buf, buf_size, *index, stride, data);
520 return Status::OK();
521 case DataType::DT_FLOAT:
522 *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
523 buf, buf_size, *index, stride, data);
524 return Status::OK();
525 default:
526 return errors::DataLoss("Failed reading TYPE_FLOAT for ",
527 DataTypeString(dtype));
528 }
529 case WireFormatLite::TYPE_INT64:
530 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
531 buf, buf_size, *index, stride, data);
532 return Status::OK();
533 case WireFormatLite::TYPE_UINT64:
534 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>(
535 buf, buf_size, *index, stride, data);
536 return Status::OK();
537 case WireFormatLite::TYPE_INT32:
538 switch (dtype) {
539 case DataType::DT_INT64:
540 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT32>(
541 buf, buf_size, *index, stride, data);
542 return Status::OK();
543 case DataType::DT_INT32:
544 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
545 buf, buf_size, *index, stride, data);
546 return Status::OK();
547 default:
548 return errors::DataLoss("Failed reading TYPE_INT32 for ",
549 DataTypeString(dtype));
550 }
551 case WireFormatLite::TYPE_FIXED64:
552 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>(
553 buf, buf_size, *index, stride, data);
554 return Status::OK();
555 case WireFormatLite::TYPE_FIXED32:
556 switch (dtype) {
557 case DataType::DT_UINT64:
558 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>(
559 buf, buf_size, *index, stride, data);
560 return Status::OK();
561 case DataType::DT_UINT32:
562 *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>(
563 buf, buf_size, *index, stride, data);
564 return Status::OK();
565 default:
566 return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
567 DataTypeString(dtype));
568 }
569 case WireFormatLite::TYPE_BOOL:
570 *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
571 buf, buf_size, *index, stride, data);
572 return Status::OK();
573 case WireFormatLite::TYPE_STRING:
574 case WireFormatLite::TYPE_GROUP:
575 case WireFormatLite::TYPE_MESSAGE:
576 case WireFormatLite::TYPE_BYTES:
577 return errors::DataLoss("Non-primitive type encountered as packed");
578 case WireFormatLite::TYPE_UINT32:
579 switch (dtype) {
580 case DataType::DT_UINT64:
581 *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>(
582 buf, buf_size, *index, stride, data);
583 return Status::OK();
584 case DataType::DT_UINT32:
585 *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>(
586 buf, buf_size, *index, stride, data);
587 return Status::OK();
588 default:
589 return errors::DataLoss("Failed reading TYPE_UINT32 for ",
590 DataTypeString(dtype));
591 }
592 case WireFormatLite::TYPE_ENUM:
593 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
594 buf, buf_size, *index, stride, data);
595 return Status::OK();
596 case WireFormatLite::TYPE_SFIXED32:
597 switch (dtype) {
598 case DataType::DT_INT64:
599 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED32>(
600 buf, buf_size, *index, stride, data);
601 return Status::OK();
602 case DataType::DT_INT32:
603 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
604 buf, buf_size, *index, stride, data);
605 return Status::OK();
606 default:
607 return errors::DataLoss("Failed reading TYPE_INT32 for ",
608 DataTypeString(dtype));
609 }
610 case WireFormatLite::TYPE_SFIXED64:
611 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
612 buf, buf_size, *index, stride, data);
613 return Status::OK();
614
615 case WireFormatLite::TYPE_SINT32:
616 switch (dtype) {
617 case DataType::DT_INT64:
618 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT32>(
619 buf, buf_size, *index, stride, data);
620 return Status::OK();
621 case DataType::DT_INT32:
622 *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
623 buf, buf_size, *index, stride, data);
624 return Status::OK();
625 default:
626 return errors::DataLoss("Failed reading TYPE_SINT32 for ",
627 DataTypeString(dtype));
628 }
629 case WireFormatLite::TYPE_SINT64:
630 *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
631 buf, buf_size, *index, stride, data);
632 return Status::OK();
633 // default: intentionally omitted in order to enable static checking.
634 }
635 // Unreachable.
636 return errors::DataLoss("Failed reading unknown wire type");
637 }
638
639 // Reads a varint from the given buffer, write it to *value, and return the
640 // new buffer pointer.
641 // This was copied from coded_stream.cc where it is private.
642 // Important: This routine may read as much as kMaxVarintBytes from
643 // the buffer. It is the caller's responsibility to make sure that there is
644 // enough space in the buffer.
ReadVarint64FromArray(const uint8 * buffer,bool * ok,uint64 * value)645 inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
646 uint64* value) {
647 const uint8* ptr = buffer;
648 uint32 b;
649
650 // Splitting into 32-bit pieces gives better performance on 32-bit
651 // processors.
652 uint32 part0 = 0, part1 = 0, part2 = 0;
653
654 b = *(ptr++);
655 part0 = b;
656 if (!(b & 0x80)) goto done;
657 part0 -= 0x80;
658 b = *(ptr++);
659 part0 += b << 7;
660 if (!(b & 0x80)) goto done;
661 part0 -= 0x80 << 7;
662 b = *(ptr++);
663 part0 += b << 14;
664 if (!(b & 0x80)) goto done;
665 part0 -= 0x80 << 14;
666 b = *(ptr++);
667 part0 += b << 21;
668 if (!(b & 0x80)) goto done;
669 part0 -= 0x80 << 21;
670 b = *(ptr++);
671 part1 = b;
672 if (!(b & 0x80)) goto done;
673 part1 -= 0x80;
674 b = *(ptr++);
675 part1 += b << 7;
676 if (!(b & 0x80)) goto done;
677 part1 -= 0x80 << 7;
678 b = *(ptr++);
679 part1 += b << 14;
680 if (!(b & 0x80)) goto done;
681 part1 -= 0x80 << 14;
682 b = *(ptr++);
683 part1 += b << 21;
684 if (!(b & 0x80)) goto done;
685 part1 -= 0x80 << 21;
686 b = *(ptr++);
687 part2 = b;
688 if (!(b & 0x80)) goto done;
689 part2 -= 0x80;
690 b = *(ptr++);
691 part2 += b << 7;
692 if (!(b & 0x80)) goto done;
693 // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.
694
695 // We have overrun the maximum size of a varint (10 bytes). Assume
696 // the data is corrupt.
697 *ok = false;
698 return ptr;
699
700 done:
701 *ok = true;
702 *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
703 (static_cast<uint64>(part2) << 56);
704 return ptr;
705 }
706
707 } // namespace internal
708 } // namespace tensorflow
709
710 #endif // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
711