/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// DecodeProto is a TensorFlow op which extracts arbitrary fields from protos
// serialized as strings.
//
// See docs in ../ops/decode_proto_op.cc.
//
// This implementation reads the serialized format using a handful of calls
// from the WireFormatLite API used by generated proto code. WireFormatLite is
// marked as an "internal" proto API but is widely used in practice and highly
// unlikely to change. This is much faster than the previous implementation,
// which constructed a temporary dynamic message in memory and read it through
// the proto reflection API. It can be used with any proto whose descriptors
// are available at runtime, and it should remain competitive in speed with
// approaches that compile in the proto definitions.
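//
// For orientation, the decode loop in this file follows the standard
// WireFormatLite tag-dispatch pattern. A minimal sketch (illustrative only,
// not part of this op):
//
//   uint32 tag;
//   while ((tag = input->ReadTag()) != 0) {
//     const int field_number = WireFormatLite::GetTagFieldNumber(tag);
//     const WireFormatLite::WireType wire_type =
//         WireFormatLite::GetTagWireType(tag);
//     // ...dispatch on (field_number, wire_type), or skip unknown fields...
//   }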

#include <algorithm>  // for std::sort, used below
#include <climits>    // for INT_MAX, used below
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/decode.h"
#include "tensorflow/core/util/proto/descriptors.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {

using ::tensorflow::MakeUnique;
using ::tensorflow::protobuf::Descriptor;
using ::tensorflow::protobuf::DescriptorPool;
using ::tensorflow::protobuf::DynamicMessageFactory;
using ::tensorflow::protobuf::FieldDescriptor;
using ::tensorflow::protobuf::Message;
using ::tensorflow::protobuf::TextFormat;
using ::tensorflow::protobuf::internal::WireFormatLite;
using ::tensorflow::protobuf::io::CodedInputStream;

const bool kFailOnDecodeError = true;

// Used to store the default value of a protocol message field, casted to the
// type of the output tensor.
//
// TODO(paskin): Use absl::variant once TensorFlow gets absl dependencies.
struct DefaultValue {
  DataType dtype = DataType::DT_INVALID;
  union Value {
    bool v_bool;           // DT_BOOL
    double v_double;       // DT_DOUBLE
    float v_float;         // DT_FLOAT
    int8 v_int8;           // DT_INT8
    int32 v_int32;         // DT_INT32
    int64 v_int64;         // DT_INT64
    const char* v_string;  // DT_STRING
    uint8 v_uint8;         // DT_UINT8
    uint32 v_uint32;       // DT_UINT32
    uint64 v_uint64;       // DT_UINT64
  };
  Value value;
};
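
// For example, given a hypothetical proto field
//   optional int32 foo = 1 [default = 7];
// requested with an output type of DT_INT64, the DefaultValue built below
// would hold dtype == DT_INT64 and value.v_int64 == 7 (assuming the
// type-compatibility checks in the op constructor have passed).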

// Initializes a DefaultValue object. This generic template handles numeric
// types; strings are handled by the template specialization below.
//
// Args:
//   dtype: the type of the output tensor
//   value: the default value as obtained from the FieldDescriptor
//   result: the object to initialize
template <typename T>
Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) {
  result->dtype = dtype;
  switch (dtype) {
    case DT_BOOL:
      result->value.v_bool = static_cast<bool>(value);
      break;
    case DT_DOUBLE:
      result->value.v_double = static_cast<double>(value);
      break;
    case DT_FLOAT:
      result->value.v_float = static_cast<float>(value);
      break;
    case DT_INT8:
      result->value.v_int8 = static_cast<int8>(value);
      break;
    case DT_INT32:
      result->value.v_int32 = static_cast<int32>(value);
      break;
    case DT_INT64:
      result->value.v_int64 = static_cast<int64>(value);
      break;
    case DT_UINT8:
      result->value.v_uint8 = static_cast<uint8>(value);
      break;
    case DT_UINT32:
      result->value.v_uint32 = static_cast<uint32>(value);
      break;
    case DT_UINT64:
      result->value.v_uint64 = static_cast<uint64>(value);
      break;
    default:
      // We should never get here, given the type checking that occurs earlier.
      return errors::Internal(
          "Cannot initialize default value for unsupported type: ",
          DataTypeString(dtype));
  }
  return Status::OK();
}

template <>
Status InitDefaultValue(DataType dtype, const char* value,
                        DefaultValue* result) {
  // These are sanity checks that should never trigger given the code that
  // leads here.
  if (TF_PREDICT_FALSE(dtype != DT_STRING)) {
    return errors::InvalidArgument(
        "Cannot cast field to anything but DT_STRING");
  }
  if (TF_PREDICT_FALSE(value == nullptr)) {
    return errors::InvalidArgument("Null default string value.");
  }
  result->dtype = DT_STRING;
  result->value.v_string = value;
  return Status::OK();
}

// Initializes a default value from the output data type and the field
// descriptor.
Status InitDefaultValueFromFieldDescriptor(DataType dtype,
                                           const FieldDescriptor* field_desc,
                                           DefaultValue* result) {
  switch (field_desc->type()) {
    case WireFormatLite::TYPE_DOUBLE:
      return InitDefaultValue(dtype, field_desc->default_value_double(),
                              result);
    case WireFormatLite::TYPE_FLOAT:
      return InitDefaultValue(dtype, field_desc->default_value_float(), result);
    case WireFormatLite::TYPE_INT64:
    case WireFormatLite::TYPE_SINT64:
    case WireFormatLite::TYPE_SFIXED64:
      return InitDefaultValue(dtype, field_desc->default_value_int64(), result);
    case WireFormatLite::TYPE_FIXED64:
    case WireFormatLite::TYPE_UINT64:
      return InitDefaultValue(dtype, field_desc->default_value_uint64(),
                              result);
    case WireFormatLite::TYPE_ENUM:
      // For enum fields the descriptor stores the default as an
      // EnumValueDescriptor, so use its numeric value rather than reading the
      // int32 default slot.
      return InitDefaultValue(dtype,
                              field_desc->default_value_enum()->number(),
                              result);
    case WireFormatLite::TYPE_INT32:
    case WireFormatLite::TYPE_SINT32:
    case WireFormatLite::TYPE_SFIXED32:
      return InitDefaultValue(dtype, field_desc->default_value_int32(), result);
    case WireFormatLite::TYPE_FIXED32:
    case WireFormatLite::TYPE_UINT32:
      return InitDefaultValue(dtype, field_desc->default_value_uint32(),
                              result);
    case WireFormatLite::TYPE_BOOL:
      return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
    case WireFormatLite::TYPE_BYTES:
    case WireFormatLite::TYPE_STRING:
      // Manipulating default string values as C-style pointers should be OK
      // for typical code-generated protocol messages. It is possible in
      // principle to register a message descriptor on the fly, and these
      // pointers may not be stable if that descriptor has a weird
      // implementation. (But the return type of default_value_string() is
      // const string&, so it'd have to be very weird.)
      return InitDefaultValue(dtype, field_desc->default_value_string().c_str(),
                              result);
    case WireFormatLite::TYPE_GROUP:
    case WireFormatLite::TYPE_MESSAGE:
      return InitDefaultValue(dtype, "", result);
      // default: intentionally omitted in order to enable static checking.
  }
  return Status::OK();
}

// A FieldInfo holds a handful of information from the FieldDescriptor
// and user attributes.
struct FieldInfo {
  FieldInfo(const FieldDescriptor* field_desc, int user_index,
            DefaultValue def_value)
      : output_index(user_index), default_value(def_value) {
    // Without this intermediate data structure, the profile had hotspots
    // calling methods of FieldDescriptor.
    number = field_desc->number();

    // The wire format library defines the same constants used in
    // descriptor.proto. This static_cast is safe because they are guaranteed
    // to stay in sync. We need the field type from the FieldDescriptor here
    // because the wire format doesn't tell us anything about what happens
    // inside a packed repeated field: there is enough information in the wire
    // format to skip the whole field but not enough to know how to parse
    // what's inside. For that we go to the schema.
    type = static_cast<WireFormatLite::FieldType>(field_desc->type());
    is_repeated = field_desc->is_repeated();
  }

  // Disable copy and move.
  FieldInfo(const FieldInfo&) = delete;
  FieldInfo& operator=(const FieldInfo&) = delete;

  // Internally we sort field descriptors by wire number for fast lookup. In
  // general this is different from the order given by the user. output_index
  // gives the index into the field_names and output_types attributes and into
  // the output tensor list.
  int output_index = -1;

  // This is a cache of the relevant fields from `FieldDescriptorProto`. This
  // was added after noticing that FieldDescriptor->type() was using 6% of the
  // cpu profile.
  WireFormatLite::FieldType type;
  int number;
  bool is_repeated;
  DefaultValue default_value;
};

// A CountCollector counts sizes of repeated and optional fields in a proto.
//
// Each field is tracked by a single CountCollector instance. The instance
// manages a single count, which is stored as a pointer (it is intended to be a
// reference to the `sizes` output which is being filled in). The pointer is
// passed in at initialization.
//
// Counting is done as a separate pass in order to allocate output tensors all
// at once. This allows the TensorFlow runtime to optimize allocation for the
// consumer, while removing the need for copying inside this op. After this
// pass, the DenseCollector class (below) gathers the data: it is more complex
// and provides better motivation for the API here.
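//
// Worked example (hypothetical input): if two messages carry 3 and 1 values
// respectively for some repeated int32 field, the counting pass records the
// per-message sizes [3, 1], the values tensor for that field is allocated
// with last dimension max(3, 1) = 3, and the shorter row is later padded with
// default values by the DenseCollector pass.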
class CountCollector {
 public:
  CountCollector() = delete;

  // The count may be stored inside an Eigen Tensor to eliminate copying.
  explicit CountCollector(int32* count) : count_ptr_(count) {}

  // Reads (in this case counts) a single value.
  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    // Only repeated fields can have count > 1.
    if (*count_ptr_ == 0 || field.is_repeated) {
      (*count_ptr_)++;
    }
    // We expect a wire type based on the schema field_type, to allow a little
    // more checking.
    if (!SkipValue(input, field)) {
      return errors::DataLoss("ReadValue: Failed skipping field when counting");
    }
    return Status::OK();
  }

  // Reads (in this case counts) a length-delimited list of values.
  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
                          size_t buf_size) {
    if (buf_size == 0) {
      return Status::OK();
    }

    const void* tmpbuf;
    int unused_max_buf_size;

    input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
    // This is safe because the underlying storage for the CodedInputStream is
    // owned by the input tensor. If it were a Cord or file-backed stream this
    // pointer would go stale after the bytes were skipped.
    const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);

    // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
    // so the bounds check on buf_size inside Skip() is critical, and
    // must be done before scanning the contents.
    if (!input->Skip(buf_size)) {
      return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
    }

    // Dispatch to the appropriately typed field reader based on the schema
    // type.
    Status st;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        st = CountPackedFixed<double>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FLOAT:
        st = CountPackedFixed<float>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_UINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED64:
        st = CountPackedFixed<uint64>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED32:
        st = CountPackedFixed<uint32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_BOOL:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_STRING:
        st = errors::DataLoss("TYPE_STRING encountered as packed");
        break;
      case WireFormatLite::TYPE_GROUP:
        st = errors::DataLoss("TYPE_GROUP encountered as packed");
        break;
      case WireFormatLite::TYPE_MESSAGE:
        st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
        break;
      case WireFormatLite::TYPE_BYTES:
        st = errors::DataLoss("TYPE_BYTES encountered as packed");
        break;
      case WireFormatLite::TYPE_UINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_ENUM:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED32:
        st = CountPackedFixed<int32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED64:
        st = CountPackedFixed<int64>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
        // default: intentionally omitted in order to enable static checking.
    }
    if (!st.ok()) {
      return st;
    }

    if (!field.is_repeated && *count_ptr_ > 1) {
      *count_ptr_ = 1;
    }
    return Status::OK();
  }

 private:
  // Skips a length-delimited value.
  static bool SkipBytes(CodedInputStream* input) {
    uint32 length;
    if (!input->ReadVarint32(&length)) {
      return false;
    }
    return input->Skip(length);
  }

  // Counts the number of packed varints in an array. The end of a varint is
  // signaled by a value < 0x80, so counting them requires parsing the
  // bytestream. It is the caller's responsibility to ensure that len > 0.
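  // For example, 300 (0b100101100) is encoded as the two bytes {0xAC, 0x02}:
  // 0xAC carries the low 7 bits with the continuation bit set, and 0x02 has
  // its high bit clear, terminating the varint.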
  Status CountPackedVarint(const uint8* buf, size_t len) {
    const uint8* bound = buf + len;
    int count;

    // The last byte in a valid encoded varint is guaranteed to have the high
    // bit unset. We rely on this property to prevent ReadVarint64FromArray
    // from going out of bounds, so validate the end of the buf before
    // scanning anything.
    if (bound[-1] & 0x80) {
      return errors::DataLoss("Corrupt packed varint");
    }

    // Now we can trust ReadVarint64FromArray to stay in bounds.
    for (count = 0; buf < bound; ++count) {
      uint64 temp;
      bool ok;
      buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
      if (!ok) {
        return errors::DataLoss("Corrupt packed varint");
      }
    }

    *count_ptr_ += count;
    return Status::OK();
  }

  // Counts the number of fixed-size values in a packed field. This can be
  // done without actually parsing anything.
  template <typename T>
  Status CountPackedFixed(const uint8* unused_buf, size_t len) {
    int count = len / sizeof(T);
    if (count * sizeof(T) != len) {
      return errors::DataLoss(
          "Illegal data length for packed fixed-size type: ", len);
    }
    *count_ptr_ += count;
    return Status::OK();
  }

  // Skips a single value in the input stream. Dispatches to the appropriately
  // typed field skipper based on the schema type tag. This is not as
  // permissive as just handling the wire type.
  static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
    uint32 tmp32;
    protobuf_uint64 tmp64;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FLOAT:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_INT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_UINT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_INT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_FIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_BOOL:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_STRING:
        return SkipBytes(input);
      case WireFormatLite::TYPE_GROUP:
        return WireFormatLite::SkipField(
            input, WireFormatLite::MakeTag(
                       field.number, WireFormatLite::WIRETYPE_START_GROUP));
      case WireFormatLite::TYPE_MESSAGE:
        return SkipBytes(input);
      case WireFormatLite::TYPE_BYTES:
        return SkipBytes(input);
      case WireFormatLite::TYPE_UINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_ENUM:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SFIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_SFIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_SINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SINT64:
        return input->ReadVarint64(&tmp64);
        // default: intentionally omitted in order to enable static checking.
    }
  }

  int32* count_ptr_ = nullptr;
};

// A DenseCollector accumulates values from a proto into a tensor.
//
// There is an instance of DenseCollector for each field of each proto. The
// DenseCollector deserializes the value from the wire directly into the
// preallocated output Tensor.
//
// This class is named DenseCollector because in the future there should be a
// SparseCollector that accumulates field data into sparse tensors if the user
// requests it.
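//
// For instance, for a DT_INT32 field with max_repeat_count 4, each
// collector's datap points at a 16-byte slice (one message's row) of the flat
// values output; parsed values land there in wire order and
// FillWithDefaults() pads any unused tail with the field default.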
class DenseCollector {
 public:
  DenseCollector() = delete;

  // A DenseCollector applies to one field of a serialized message.
  // Note that default_value.dtype is the type of the output tensor.
  DenseCollector(uint8* datap, DefaultValue default_value, int max_repeat_count)
      : datap_(datap),
        default_value_(default_value),
        max_repeat_count_(max_repeat_count) {}

  // Reads a value from the input stream and stores it.
  //
  // Always inlining gave a ~50% speedup on microbenchmarks at one point.
  // TODO(nix): try removing it to see if that still holds.
  // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    // For required and optional fields, we overwrite values[0] with
    // the latest one in the wire stream.
    // See https://developers.google.com/protocol-buffers/docs/encoding#optional
    // Only for repeated fields do we advance the next_repeat_index_ past 1.
    // TODO(nix): to handle oneof we must also zero out any previous values
    // seen on the wire.
    int32 index = 0;
    if (field.is_repeated) {
      index = next_repeat_index_;
    }
    next_repeat_index_ = index + 1;

    return internal::ReadValue(input, field.type, field.number,
                               default_value_.dtype, index, datap_);
  }

  // Reads and stores a length-delimited list of values.
  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
                          const size_t buf_size) {
    const void* buf;
    int unused_max_buf_size;
    input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
    // This is safe because the underlying storage for the CodedInputStream is
    // owned by the input tensor. If it were a Cord or file-backed stream this
    // pointer would go stale after the bytes were skipped.
    if (!input->Skip(buf_size)) {
      return errors::DataLoss(
          "ReadPackedValues: Skipping packed field failed. Field tag: ",
          field.number);
    }

    // Setting stride=0 causes new values to overwrite old ones for
    // non-repeated fields.
    const int stride = field.is_repeated ? 1 : 0;

    if (next_repeat_index_ >= max_repeat_count_) {
      return errors::DataLoss(
          "ReadPackedValues: Tried to write more entries than allowed. "
          "Field tag: ",
          field.number, ", Max entries allowed: ", max_repeat_count_);
    } else {
      return internal::ReadPackedFromArray(buf, buf_size, field.type,
                                           field.number, default_value_.dtype,
                                           stride, &next_repeat_index_, datap_);
    }
  }

  // Fills in any missing values in the output array with defaults. Dispatches
  // to the appropriately typed field default based on the runtime type tag.
  Status FillWithDefaults() {
    switch (default_value_.dtype) {
      case DataType::DT_BOOL:
        return FillDefault<bool>(default_value_.value.v_bool);
      case DataType::DT_FLOAT:
        return FillDefault<float>(default_value_.value.v_float);
      case DataType::DT_DOUBLE:
        return FillDefault<double>(default_value_.value.v_double);
      case DataType::DT_INT8:
        return FillDefault<int8>(default_value_.value.v_int8);
      case DataType::DT_INT32:
        return FillDefault<int32>(default_value_.value.v_int32);
      case DataType::DT_INT64:
        return FillDefault<int64>(default_value_.value.v_int64);
      case DataType::DT_STRING:
        return FillDefault<string>(default_value_.value.v_string);
      case DataType::DT_UINT8:
        return FillDefault<uint8>(default_value_.value.v_uint8);
      case DataType::DT_UINT32:
        return FillDefault<uint32>(default_value_.value.v_uint32);
      case DataType::DT_UINT64:
        return FillDefault<uint64>(default_value_.value.v_uint64);
      default:
        // There are many tensorflow dtypes not handled here, but they
        // should not come up unless type casting is added to the Op.
        // Chaining with tf.cast() should do the right thing until then.
        return errors::DataLoss("Failed filling defaults for ",
                                DataTypeString(default_value_.dtype));
    }
  }

 private:
  // Fills empty values in the dense representation with a default value. This
  // uses next_repeat_index_ which counts the number of parsed values for the
  // field.
  template <class T>
  Status FillDefault(const T& default_value) {
    for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
      reinterpret_cast<T*>(datap_)[i] = default_value;
    }
    return Status::OK();
  }

  int32 next_repeat_index_ = 0;

  // This is a pointer into the output tensor's flat buffer for one message.
  // There is no bounds checking at this level: we computed the max repeat
  // size for each field in CountCollector and use the same code to traverse
  // it here, so we are guaranteed not to be called for more items than we
  // have allocated space.
  void* const datap_ = nullptr;

  const DefaultValue default_value_;
  const int max_repeat_count_ = 0;
};

class DecodeProtoOp : public OpKernel {
 public:
  explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
    string descriptor_source;
    OP_REQUIRES_OK(context,
                   context->GetAttr("descriptor_source", &descriptor_source));

    // We always get back a desc_pool, but we may not own it. If we own it,
    // owned_desc_pool_ will be filled in.
    DescriptorPool const* desc_pool;
    OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
                                              &desc_pool, &owned_desc_pool_));

    string message_type;
    OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
    // Save the type name so that decode-error warnings can identify the
    // message.
    message_type_ = message_type;

    const Descriptor* message_desc =
        desc_pool->FindMessageTypeByName(message_type);
    OP_REQUIRES(context, message_desc != nullptr,
                errors::InvalidArgument("No descriptor found for message type ",
                                        message_type));

    std::vector<string> field_names;
    OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
    std::vector<DataType> output_types;
    OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
    OP_REQUIRES(
        context, field_names.size() == output_types.size(),
        errors::InvalidArgument("field_names and output_types attributes must "
                                "have the same length"));

    // Gather the field descriptors and check that the requested output types
    // match.
    int field_index = 0;
    std::vector<const FieldDescriptor*> field_descs;
    std::vector<const FieldDescriptor*> exts;
    absl::flat_hash_map<string, const FieldDescriptor*> ext_name_to_field;
    std::vector<const FieldDescriptor*>::iterator ext_it = exts.begin();
    for (const string& name : field_names) {
      auto fd = message_desc->FindFieldByName(name);
      if (fd == nullptr) {
        // If the field can't be found in the original message, try to find a
        // matching extension (by its full_name). First check a hashmap for a
        // matching extension, and if not found, iterate through the available
        // extensions to find a match (updating the hashmap while iterating).
        auto lookup_result = ext_name_to_field.find(name);
        if (lookup_result != ext_name_to_field.end()) {
          fd = lookup_result->second;
        } else {
          if (ext_it == exts.begin()) {
            desc_pool->FindAllExtensions(message_desc, &exts);
            ext_it = exts.begin();
          }
          while (ext_it != exts.end()) {
            auto ext_name = (*ext_it)->full_name();
            auto ext_field = *ext_it;
            ++ext_it;

            ext_name_to_field.insert({ext_name, ext_field});
            if (ext_name == name) {
              fd = ext_field;
              break;
            }
          }
        }
      }
      OP_REQUIRES(context, fd != nullptr,
                  errors::InvalidArgument("Unknown field: ", name,
                                          " in message type ", message_type));
      OP_REQUIRES(
          context,
          proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
          // Many TensorFlow types don't have corresponding proto types and the
          // user will get an error if they are requested. It would be nice to
          // allow conversions here, but tf.cast already exists so we don't
          // duplicate the functionality.
          errors::InvalidArgument("Unexpected output type for ",
                                  fd->full_name(), ": ", fd->cpp_type(), " to ",
                                  output_types[field_index]));

      field_index++;
      field_descs.push_back(fd);
    }

    // Internally we want the field_descs sorted by their number on the wire.
    // But the output tensors are allocated in the order given by the caller.
    // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
    std::vector<int> output_indices;
    output_indices.reserve(field_names.size());
    for (int i = 0; i < field_names.size(); i++) {
      output_indices.push_back(i);
    }
    std::sort(output_indices.begin(), output_indices.end(),
              [field_descs](int a, int b) {
                return field_descs[a]->number() < field_descs[b]->number();
              });

    // Now store the fields in sorted order.
    for (int i = 0; i < field_names.size(); i++) {
      const int output_index = output_indices[i];
      const DataType dtype = output_types[output_index];
      const FieldDescriptor* field_descriptor = field_descs[output_index];
      DefaultValue default_value;
      OP_REQUIRES_OK(context, InitDefaultValueFromFieldDescriptor(
                                  dtype, field_descriptor, &default_value));
      fields_.push_back(
          MakeUnique<FieldInfo>(field_descriptor, output_index, default_value));
    }

    message_prototype_ = message_factory_.GetPrototype(message_desc);
    OP_REQUIRES(context, message_prototype_ != nullptr,
                errors::InvalidArgument("Couldn't get prototype message: ",
                                        message_desc->full_name()));
    string format;
    OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
    OP_REQUIRES(
        context, format == "binary" || format == "text",
        errors::InvalidArgument("format must be one of binary or text"));
    is_binary_ = format == "binary";

    // Enable the initial protobuf sanitizer, which is much more expensive
    // than the decoder.
    // TODO(nix): Remove this once the fast decoder has passed security review.
    OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& buf_tensor = ctx->input(0);
    int message_count = buf_tensor.NumElements();
    OP_REQUIRES(ctx, message_count >= 1,
                errors::InvalidArgument(
                    "Bufs argument must contain at least one value"));

    int field_count = fields_.size();

    // Save the argument shape for later, then flatten the input Tensor since
    // we are working componentwise. We will restore the same shape in the
    // returned Tensor.
    const TensorShape& shape_prefix = buf_tensor.shape();

    TensorShape sizes_shape = shape_prefix;
    sizes_shape.AddDim(field_count);
    Tensor* sizes_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));

    // This is used to allocate binary bufs if used. It serves only to define
    // memory ownership.
    std::vector<string> tmp_binary_bufs(message_count);

    // These are the actual buffers to use, which may be in tmp_binary_bufs
    // or may be pointers into the buf_tensor. Either way they are not owned
    // here.
    std::vector<const string*> bufs;

    if (is_binary_ && !sanitize_) {
      // Fast path.
      for (int mi = 0; mi < message_count; ++mi) {
        const string* buf = &buf_tensor.flat<string>()(mi);
        bufs.push_back(buf);
      }
    } else {
      // We will have to allocate a copy, either to convert from text to
      // binary or to sanitize a binary proto.
      for (int mi = 0; mi < message_count; ++mi) {
        ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
                           &tmp_binary_bufs[mi]);
        if (!ctx->status().ok()) {
          return;
        }
        bufs.push_back(&tmp_binary_bufs[mi]);
      }
    }

    // Walk through all the strings in the input tensor, counting the number
    // of fields in each. We can't allocate our actual output Tensor until we
    // know the maximum repeat count, so we do a first pass through the
    // serialized proto just counting fields. We always allocate at least one
    // value so that optional fields are populated with default values - this
    // avoids a TF conditional when handling the output data. The caller can
    // distinguish between real data and defaults using the repeat count
    // matrix that is returned by decode_proto.
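    // For example, an optional field that is absent from every input message
    // gets a count of 0 recorded in the sizes output, but its values output
    // still has last dimension 1 (max_sizes starts at 1 below) and that slot
    // is filled with the field's default value.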
    std::vector<int32> max_sizes(field_count, 1);
    for (int mi = 0; mi < message_count; ++mi) {
      CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
      if (!ctx->status().ok()) {
        return;
      }
    }

    // Allocate the output tensors now that we've seen the max size.
    // TODO(nix): Use allocate_output_or_forward_input for the largest
    // output tensor. This can avoid one large allocation by re-using
    // the memory of the input tensor.
    std::vector<Tensor*> outputs(field_count);
    for (int fi = 0; fi < field_count; ++fi) {
      TensorShape out_shape = shape_prefix;
      out_shape.AddDim(max_sizes[fi]);

      // Surprisingly we don't specify the types from the output_types
      // attribute: that is done for us based on the Op declaration:
      //  REGISTER_OP(...)
      //    .Attr("output_types: list(type) >= 0")
      //    .Output("values: output_types")
      OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
                                               out_shape, &outputs[fi]));
    }

    // Make the second pass through the serialized proto, decoding into
    // preallocated tensors.
    AccumulateFields(ctx, bufs, outputs);
  }

 private:
  // Copies a serialized message to binary, e.g. to handle text proto inputs.
  void ReserializeMessage(OpKernelContext* ctx, const string& buf,
                          string* binary_buf) {
    // Handle text protos by translating them to binary.
    std::unique_ptr<Message> message(message_prototype_->New());
    OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed"));

    if (is_binary_) {
      // If we get here we are sanitizing the input protobuf by parsing
      // and reserializing it with a trusted (but very slow) library.
      OP_REQUIRES(ctx, message->ParseFromString(buf),
                  errors::DataLoss("Unable to parse binary protobuf"));
    } else {
      OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()),
                  errors::DataLoss("Unable to parse text protobuf"));
    }

    OP_REQUIRES(ctx, message->SerializeToString(binary_buf),
                errors::DataLoss("Unable to reserialize text proto as binary"));
  }

  // Counts the number of occurrences of each requested field in one message
  // of the batch.
  void CountFields(OpKernelContext* ctx, int message_index, const string& buf,
                   Tensor* sizes_tensor, std::vector<int32>* max_sizes) {
    int field_count = fields_.size();

    CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
                           buf.size());

    std::vector<int32> field_sizes(field_count, 0);
    std::vector<CountCollector> counters;
    counters.reserve(field_count);
    for (int i = 0; i < field_count; i++) {
      counters.emplace_back(&field_sizes[i]);
    }

    Status st = Collect(&input, &counters);
    if (st.ok() && !input.ConsumedEntireMessage()) {
      st = errors::DataLoss("CountFields: Failed to consume entire buffer");
    }
    if (kFailOnDecodeError) {
      OP_REQUIRES_OK(ctx, st);  // NOLINT
    }
    if (!st.ok()) {
      // This code suppresses the corrupt proto, treating it as empty
      // to avoid crashing the process.
      LOG(WARNING) << "Proto counting error for message type " << message_type_
                   << ": " << st;

      // Record zero counts for this message: reset any partial counts and
      // fall through to the update loop below so the sizes output is still
      // written.
      for (int fi = 0; fi < field_count; fi++) {
        field_sizes[fi] = 0;
      }
    }

    // Update the size tensor and max repeat size for each field.
    auto sizes = sizes_tensor->flat_inner_dims<int32>();
    for (int fi = 0; fi < field_count; fi++) {
      int32 size = field_sizes[fi];
      sizes(message_index, fields_[fi]->output_index) = size;
      if ((*max_sizes)[fi] < size) {
        (*max_sizes)[fi] = size;
      }
    }
  }

  // Parses fields from a serialized message into preallocated tensors.
  void AccumulateFields(OpKernelContext* ctx,
                        const std::vector<const string*>& bufs,
                        std::vector<Tensor*> outputs) {
    struct TensorInfo {
      explicit TensorInfo(Tensor* tensor) {
        // Note that we can decode only max_repeat_count values before
        // overflow. No other bounds checking is done for repeated fields. For
        // optional fields there is a check to make sure that only the last
        // value on the wire appears in the output tensor.
        dtype = tensor->dtype();
        last_dim_size = tensor->dim_size(tensor->dims() - 1);

        if (dtype != DT_STRING) {
          const int element_size = DataTypeSize(dtype);
          CHECK_GT(element_size, 0);
          stride = last_dim_size * element_size;

          const int64 flatshape[1] = {tensor->NumElements() * element_size};
          data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
        } else {
          // DataTypeSize() returns 0 for string types.
          stride = last_dim_size * sizeof(string);
          data = reinterpret_cast<uint8*>(tensor->flat<string>().data());
        }
      }

      DataType dtype;
      int last_dim_size;
      int stride;
      uint8* data;
    };

    int field_count = fields_.size();

    std::vector<TensorInfo> tensors;
    tensors.reserve(field_count);
    for (int fi = 0; fi < field_count; fi++) {
      tensors.emplace_back(outputs[fi]);
    }

    for (int message_index = 0; message_index < bufs.size(); ++message_index) {
      const string& buf = *bufs[message_index];

      std::vector<DenseCollector> collectors;
      collectors.reserve(field_count);
      for (int output_index = 0; output_index < field_count; ++output_index) {
        const TensorInfo& info = tensors[output_index];
        const FieldInfo* field_info = fields_[output_index].get();
        DCHECK(field_info != nullptr);
        const DefaultValue default_value = field_info->default_value;
        collectors.emplace_back(info.data + message_index * info.stride,
                                default_value, info.last_dim_size);
      }

      // Fill in output tensors from the wire.
      CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
                             buf.size());
      Status st = Collect(&input, &collectors);
      if (st.ok() && !input.ConsumedEntireMessage()) {
        st = errors::DataLoss(
            "AccumulateFields: Failed to consume entire buffer");
      }
      if (kFailOnDecodeError) {
        OP_REQUIRES_OK(ctx, st);  // NOLINT
      }
      if (!st.ok()) {
        // This code suppresses the corrupt proto, treating it as empty
        // to avoid crashing training.
        LOG(WARNING) << "Proto parsing error for message type "
                     << message_type_ << ": " << st;
      }

      // Fill the remainder of the dense outputs with default values.
      for (auto& collector : collectors) {
        OP_REQUIRES_OK(ctx, collector.FillWithDefaults());
      }
    }
  }

  // Looks up the FieldDescriptor for a particular field number.
  bool LookupField(int field_number, int* field_index) {
    // Look up the FieldDescriptor using linear search.
    //
    // TODO(nix): this could be sped up with binary search, but we are
    // already way off the fastpath at this point. If you see a hotspot
    // here, somebody is sending you very inefficient protos.
    for (int fi = fields_.size() - 1; fi >= 0; fi--) {
      if (field_number == fields_[fi]->number) {
        *field_index = fi;
        return true;
      }
    }
    return false;
  }

  // Traverses a serialized protobuf, dispatching values to the collectors.
  template <class CollectorClass>
  Status Collect(CodedInputStream* input,
                 std::vector<CollectorClass>* collectors) {
    int last_good_field_index = -1;
    bool fields_disordered = false;
    int prev_field_number = -1;
    int field_number = -1;
    int last_good_field_number = -1;
    int next_good_field_number = fields_[0]->number;

    // The 'tag' variable should always be treated as tainted.
    for (uint32 tag = input->ReadTag();
         tag != 0 && WireFormatLite::GetTagWireType(tag) !=
                         WireFormatLite::WIRETYPE_END_GROUP;
         tag = input->ReadTag(), prev_field_number = field_number) {
      field_number = WireFormatLite::GetTagFieldNumber(tag);
      const FieldInfo* field = nullptr;

      // This takes advantage of the sorted field numbers in most serialized
      // protos: it tries the next expected field first rather than doing
      // a lookup by field number.
      //
      // TODO(nix): haberman@ suggests a hybrid approach with a lookup table
      // for small field numbers and a hash table for larger ones. This would
      // be a simpler approach that should offer comparable speed in most
      // cases.
      if (field_number == last_good_field_number) {
        field = fields_[last_good_field_index].get();
      } else {
        if (field_number < prev_field_number) {
          fields_disordered = true;
        }

        // If fields are out of order, fall back to slow lookup.
        if (fields_disordered) {
          int field_index;
          if (LookupField(field_number, &field_index)) {
            field = fields_[field_index].get();
            last_good_field_index = field_index;
          }
        } else {
          // If we see a field that is past the next field we want, it was
          // empty. Look for the one after that. Repeat until we run out of
          // fields that we care about.
          while (field_number >= next_good_field_number) {
            if (field_number == next_good_field_number) {
              last_good_field_number = field_number;
              field = fields_[last_good_field_index + 1].get();
            }

            // Start looking for the field after the current one.
            ++last_good_field_index;
            if (last_good_field_index < fields_.size() - 1) {
              next_good_field_number =
                  fields_[last_good_field_index + 1]->number;
            } else {
              // Saw something past the last field we care about. Continue
              // parsing the message just in case there are disordered fields
              // later, but any remaining ordered fields will have no effect.
              next_good_field_number = INT_MAX;
            }
          }
        }
      }

      if (!field) {
        // Unknown and unrequested fields are skipped.
        if (!WireFormatLite::SkipField(input, tag)) {
          return errors::DataLoss("Failed skipping unrequested field");
        }
        continue;
      }

      Status st = CollectField(*field, WireFormatLite::GetTagWireType(tag),
                               input, &(*collectors)[last_good_field_index]);
      if (!st.ok()) {
        return st;
      }
    }
    return Status::OK();
  }

  // Collects values for a single field.
  template <class CollectorClass>
  Status CollectField(const FieldInfo& field,
                      WireFormatLite::WireType wire_type,
                      CodedInputStream* input, CollectorClass* collector) {
    // The wire format library defines the same constants used in
    // descriptor.proto. This static_cast is safe because they are guaranteed
    // to stay in sync.
    //
    // We need the field type from the FieldDescriptor here because the wire
    // format doesn't tell us anything about what happens inside a packed
    // repeated field: there is enough information in the wire format to skip
    // the whole field but not enough to know how to parse what's inside. For
    // that we go to the schema.
    WireFormatLite::WireType schema_wire_type =
        WireFormatLite::WireTypeForFieldType(field.type);

    // Handle packed repeated fields. SkipField would skip the whole
    // length-delimited blob without letting us count the values, so we have
    // to scan them ourselves.
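    // For example, a packed `repeated int32` field arrives on the wire as a
    // single WIRETYPE_LENGTH_DELIMITED blob even though the schema wire type
    // for TYPE_INT32 is WIRETYPE_VARINT; that mismatch is what routes us into
    // the packed branch below.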
    if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
        schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
      // Handle packed repeated primitives.
      int length;
      if (!input->ReadVarintSizeAsInt(&length)) {
        return errors::DataLoss("CollectField: Failed reading packed size");
      }
      return collector->ReadPackedValues(input, field, length);
    }

    // Read ordinary values, including strings, bytes, and messages.
    if (wire_type != schema_wire_type) {
      if (!WireFormatLite::SkipField(
              input, WireFormatLite::MakeTag(field.number, wire_type))) {
        return errors::DataLoss(
            "CollectField: Failed skipping malformed field");
      }
      return Status::OK();
    }
    return collector->ReadValue(input, field);
  }

  string message_type_;
  // Note that fields are sorted by increasing field number, which is not in
  // general the order given by the user-specified field_names and output_types
  // Op attributes.
  std::vector<std::unique_ptr<const FieldInfo>> fields_;

  // owned_desc_pool_ is null when using descriptor_source=local.
  std::unique_ptr<DescriptorPool> owned_desc_pool_;
  DynamicMessageFactory message_factory_;
  const Message* message_prototype_;

  // True if decoding binary format, false if decoding text format.
  bool is_binary_;

  // True if the protos should be sanitized before parsing. Enables the
  // initial protobuf sanitizer, which is much more expensive than the
  // decoder. The flag defaults to true but can be set to false for trusted
  // sources.
  //
  // TODO(nix): Flip the default to false when the fast decoder has passed
  // security review.
  bool sanitize_;

  TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
};

REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU),
                        DecodeProtoOp);

}  // namespace
}  // namespace tensorflow