• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // DecodeProto is a TensorFlow op which extracts arbitrary fields from protos
17 // serialized as strings.
18 //
19 // See docs in ../ops/decode_proto_op.cc.
20 //
21 // This implementation reads the serialized format using a handful of calls from
22 // the WireFormatLite API used by generated proto code. WireFormatLite is marked
23 // as an "internal" proto API but is widely used in practice and highly unlikely
24 // to change. This will be much faster than the previous implementation based on
25 // constructing a temporary dynamic message in memory and using the proto
26 // reflection api to read it. It can be used with any proto whose descriptors
27 // are available at runtime but should be competitive in speed with approaches
28 // that compile in the proto definitions.
29 
30 #include <memory>
31 #include <string>
32 #include <vector>
33 
34 #include "absl/container/flat_hash_map.h"
35 #include "third_party/eigen3/Eigen/Core"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/tensor_types.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/protobuf.h"
42 #include "tensorflow/core/util/proto/decode.h"
43 #include "tensorflow/core/util/proto/descriptors.h"
44 #include "tensorflow/core/util/proto/proto_utils.h"
45 #include "tensorflow/core/util/ptr_util.h"
46 
47 namespace tensorflow {
48 namespace {
49 
50 using ::tensorflow::MakeUnique;
51 using ::tensorflow::protobuf::Descriptor;
52 using ::tensorflow::protobuf::DescriptorPool;
53 using ::tensorflow::protobuf::DynamicMessageFactory;
54 using ::tensorflow::protobuf::FieldDescriptor;
55 using ::tensorflow::protobuf::Message;
56 using ::tensorflow::protobuf::TextFormat;
57 using ::tensorflow::protobuf::internal::WireFormatLite;
58 using ::tensorflow::protobuf::io::CodedInputStream;
59 
// When true, a failure to decode a serialized proto is treated as a hard op
// error rather than a best-effort skip.  (The consuming code path is outside
// this chunk — presumably checked in Compute(); TODO confirm.)
const bool kFailOnDecodeError = true;
61 
62 // Used to store the default value of a protocol message field, casted to the
63 // type of the output tensor.
64 //
65 // TODO(paskin): Use absl::variant once TensorFlow gets absl dependencies.
66 struct DefaultValue {
67   DataType dtype = DataType::DT_INVALID;
68   union Value {
69     bool v_bool;           // DT_BOOL
70     double v_double;       // DT_DOUBLE
71     float v_float;         // DT_FLOAT
72     int8 v_int8;           // DT_INT8
73     int32 v_int32;         // DT_INT32
74     int64 v_int64;         // DT_INT64
75     const char* v_string;  // DT_STRING
76     uint8 v_uint8;         // DT_UINT8
77     uint8 v_uint32;        // DT_UINT32
78     uint8 v_uint64;        // DT_UINT64
79   };
80   Value value;
81 };
82 
83 // Initializes a DefaultValue object.  This generic template handles numeric
84 // types and strings are handled by a template specialization below.
85 //
86 // Args:
87 //   dtype: the type of the output tensor
88 //   value: the default value as obtained from the FieldDescriptor
89 //   result: the object to initialize
90 template <typename T>
InitDefaultValue(DataType dtype,const T value,DefaultValue * result)91 Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) {
92   result->dtype = dtype;
93   switch (dtype) {
94     case DT_BOOL:
95       result->value.v_bool = static_cast<bool>(value);
96       break;
97     case DT_DOUBLE:
98       result->value.v_double = static_cast<double>(value);
99       break;
100     case DT_FLOAT:
101       result->value.v_float = static_cast<float>(value);
102       break;
103     case DT_INT8:
104       result->value.v_int8 = static_cast<int8>(value);
105       break;
106     case DT_INT32:
107       result->value.v_int32 = static_cast<int32>(value);
108       break;
109     case DT_INT64:
110       result->value.v_int64 = static_cast<int64>(value);
111       break;
112     case DT_UINT8:
113       result->value.v_uint8 = static_cast<uint8>(value);
114       break;
115     case DT_UINT32:
116       result->value.v_uint32 = static_cast<uint32>(value);
117       break;
118     case DT_UINT64:
119       result->value.v_uint64 = static_cast<uint64>(value);
120       break;
121     default:
122       // We should never get here, given the type checking that occurs earlier.
123       return errors::Internal(
124           "Cannot initialize default value for unsupported type: ",
125           DataTypeString(dtype));
126   }
127   return Status::OK();
128 }
129 
130 template <>
InitDefaultValue(DataType dtype,const char * value,DefaultValue * result)131 Status InitDefaultValue(DataType dtype, const char* value,
132                         DefaultValue* result) {
133   // These are sanity checks that should never trigger given the code that
134   // leads here.
135   if (TF_PREDICT_FALSE(dtype != DT_STRING)) {
136     return errors::InvalidArgument(
137         "Cannot cast field to anything but DT_STRING");
138   }
139   if (TF_PREDICT_FALSE(value == nullptr)) {
140     return errors::InvalidArgument("Null default string value.");
141   }
142   result->dtype = DT_STRING;
143   result->value.v_string = value;
144   return Status::OK();
145 }
146 
// Initializes a default value from the output data type and the field
// descriptor.  Dispatches on the proto schema type of the field and pulls
// the appropriate default_value_*() accessor; the actual cast to the output
// dtype happens inside InitDefaultValue.
Status InitDefaultValueFromFieldDescriptor(DataType dtype,
                                           const FieldDescriptor* field_desc,
                                           DefaultValue* result) {
  switch (field_desc->type()) {
    case WireFormatLite::TYPE_DOUBLE:
      return InitDefaultValue(dtype, field_desc->default_value_double(),
                              result);
    case WireFormatLite::TYPE_FLOAT:
      return InitDefaultValue(dtype, field_desc->default_value_float(), result);
    case WireFormatLite::TYPE_INT64:
    case WireFormatLite::TYPE_SINT64:
    case WireFormatLite::TYPE_SFIXED64:
      return InitDefaultValue(dtype, field_desc->default_value_int64(), result);
    case WireFormatLite::TYPE_FIXED64:
    case WireFormatLite::TYPE_UINT64:
      return InitDefaultValue(dtype, field_desc->default_value_uint64(),
                              result);
    case WireFormatLite::TYPE_ENUM:
    case WireFormatLite::TYPE_INT32:
    case WireFormatLite::TYPE_SINT32:
    case WireFormatLite::TYPE_SFIXED32:
      return InitDefaultValue(dtype, field_desc->default_value_int32(), result);
    case WireFormatLite::TYPE_FIXED32:
    case WireFormatLite::TYPE_UINT32:
      return InitDefaultValue(dtype, field_desc->default_value_uint32(),
                              result);
    case WireFormatLite::TYPE_BOOL:
      return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
    case WireFormatLite::TYPE_BYTES:
    case WireFormatLite::TYPE_STRING:
      // Manipulating default string values as C-style pointers should be OK
      // for typical code-generated protocol messages.  It is possible in
      // principle to register a message descriptor on the fly, and these
      // pointers may not be stable if that descriptor has a weird
      // implementation.  (But the return type of default_value_string() is
      // const string&, so it'd have to be very weird.)
      return InitDefaultValue(dtype, field_desc->default_value_string().c_str(),
                              result);
    case WireFormatLite::TYPE_GROUP:
    case WireFormatLite::TYPE_MESSAGE:
      // Message-valued fields have no scalar default; an empty string stands
      // in (only DT_STRING outputs are valid here — the specialization above
      // rejects anything else).
      return InitDefaultValue(dtype, "", result);
      // default: intentionally omitted in order to enable static checking.
  }
  return Status::OK();
}
194 
195 // A FieldInfo holds a handful of information from the FieldDescriptor
196 // and user attributes.
197 struct FieldInfo {
FieldInfotensorflow::__anon8bcbf5420111::FieldInfo198   FieldInfo(const FieldDescriptor* field_desc, int user_index,
199             DefaultValue def_value)
200       : output_index(user_index), default_value(def_value) {
201     // Without this intermediate data structure, the profile had hotspots
202     // calling methods of FieldDescriptor.
203     number = field_desc->number();
204 
205     // The wire format library defines the same constants used in
206     // descriptor.proto. This static_cast is safe because they are guaranteed to
207     // stay in sync. We need the field type from the FieldDescriptor here
208     // because the wire format doesn't tell us anything about what happens
209     // inside a packed repeated field: there is enough information in the wire
210     // format to skip the whole field but not enough to know how to parse what's
211     // inside. For that we go to the schema.
212     type = static_cast<WireFormatLite::FieldType>(field_desc->type());
213     is_repeated = field_desc->is_repeated();
214   }
215 
216   // Disable copy and move.
217   FieldInfo(const FieldInfo&) = delete;
218   FieldInfo& operator=(const FieldInfo&) = delete;
219 
220   // Internally we sort field descriptors by wire number for fast lookup. In
221   // general this is different from the order given by the user. Output_index
222   // gives the index into the field_names and output_types attributes and into
223   // the output tensor list.
224   int output_index = -1;
225 
226   // This is a cache of the relevant fields from `FieldDescriptorProto`. This
227   // was added after noticing that FieldDescriptor->type() was using 6% of the
228   // cpu profile.
229   WireFormatLite::FieldType type;
230   int number;
231   bool is_repeated;
232   DefaultValue default_value;
233 };
234 
235 // A CountCollector counts sizes of repeated and optional fields in a proto.
236 //
237 // Each field is tracked by a single CountCollector instance. The instance
238 // manages a single count, which is stored as a pointer (it is intended to be a
239 // reference to the `sizes` output which is being filled in). The pointer is
240 // passed in at initialization.
241 //
242 // Counting is done as a separate pass in order to allocate output tensors all
243 // at once. This allows the TensorFlow runtime to optimize allocation for the
244 // consumer, while removing the need for copying inside this op. After this
245 // pass, the DenseCollector class (below) gathers the data: it is more complex
246 // and provides better motivation for the API here.
247 class CountCollector {
248  public:
249   CountCollector() = delete;
250 
251   // The count may be stored inside an Eigen Tensor to eliminate copying.
CountCollector(int32 * count)252   explicit CountCollector(int32* count) : count_ptr_(count) {}
253 
254   // Reads (in this case counts) a single value.
ReadValue(CodedInputStream * input,const FieldInfo & field)255   Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
256     // Only repeated fields can have count > 1.
257     if (*count_ptr_ == 0 || field.is_repeated) {
258       (*count_ptr_)++;
259     }
260     // We expect a wire type based on the schema field_type, to allow a little
261     // more checking.
262     if (!SkipValue(input, field)) {
263       return errors::DataLoss("ReadValue: Failed skipping field when counting");
264     }
265     return Status::OK();
266   }
267 
268   // Reads (in this case counts) a length-delimited list of values.
ReadPackedValues(CodedInputStream * input,const FieldInfo & field,size_t buf_size)269   Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
270                           size_t buf_size) {
271     if (buf_size == 0) {
272       return Status::OK();
273     }
274 
275     const void* tmpbuf;
276     int unused_max_buf_size;
277 
278     input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
279     // This is safe because the underlying storage for the CodedInputStream is
280     // owned by the input tensor. If it were a Cord or file-backed stream this
281     // pointer would go stale after the bytes were skipped.
282     const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);
283 
284     // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
285     // so the bounds check on buf_size inside Skip() is critical, and
286     // must be done before scanning the contents.
287     if (!input->Skip(buf_size)) {
288       return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
289     }
290 
291     // Dispatch to the appropriately typed field reader based on the schema
292     // type.
293     Status st;
294     switch (field.type) {
295       case WireFormatLite::TYPE_DOUBLE:
296         st = CountPackedFixed<double>(buf, buf_size);
297         break;
298       case WireFormatLite::TYPE_FLOAT:
299         st = CountPackedFixed<float>(buf, buf_size);
300         break;
301       case WireFormatLite::TYPE_INT64:
302         st = CountPackedVarint(buf, buf_size);
303         break;
304       case WireFormatLite::TYPE_UINT64:
305         st = CountPackedVarint(buf, buf_size);
306         break;
307       case WireFormatLite::TYPE_INT32:
308         st = CountPackedVarint(buf, buf_size);
309         break;
310       case WireFormatLite::TYPE_FIXED64:
311         st = CountPackedFixed<uint64>(buf, buf_size);
312         break;
313       case WireFormatLite::TYPE_FIXED32:
314         st = CountPackedFixed<uint32>(buf, buf_size);
315         break;
316       case WireFormatLite::TYPE_BOOL:
317         st = CountPackedVarint(buf, buf_size);
318         break;
319       case WireFormatLite::TYPE_STRING:
320         st = errors::DataLoss("TYPE_STRING encountered as packed");
321         break;
322       case WireFormatLite::TYPE_GROUP:
323         st = errors::DataLoss("TYPE_GROUP encountered as packed");
324         break;
325       case WireFormatLite::TYPE_MESSAGE:
326         st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
327         break;
328       case WireFormatLite::TYPE_BYTES:
329         st = errors::DataLoss("TYPE_BYTES encountered as packed");
330         break;
331       case WireFormatLite::TYPE_UINT32:
332         st = CountPackedVarint(buf, buf_size);
333         break;
334       case WireFormatLite::TYPE_ENUM:
335         st = CountPackedVarint(buf, buf_size);
336         break;
337       case WireFormatLite::TYPE_SFIXED32:
338         st = CountPackedFixed<int32>(buf, buf_size);
339         break;
340       case WireFormatLite::TYPE_SFIXED64:
341         st = CountPackedFixed<int64>(buf, buf_size);
342         break;
343       case WireFormatLite::TYPE_SINT32:
344         st = CountPackedVarint(buf, buf_size);
345         break;
346       case WireFormatLite::TYPE_SINT64:
347         st = CountPackedVarint(buf, buf_size);
348         break;
349         // default: intentionally omitted in order to enable static checking.
350     }
351     if (!st.ok()) {
352       return st;
353     }
354 
355     if (!field.is_repeated && *count_ptr_ > 1) {
356       *count_ptr_ = 1;
357     }
358     return Status::OK();
359   }
360 
361  private:
362   // Skips a length-delimited value.
SkipBytes(CodedInputStream * input)363   static bool SkipBytes(CodedInputStream* input) {
364     uint32 length;
365     if (!input->ReadVarint32(&length)) {
366       return false;
367     }
368     return input->Skip(length);
369   }
370 
371   // Counts the number of packed varints in an array. The end of a varint is
372   // signaled by a value < 0x80, so counting them requires parsing the
373   // bytestream. It is the caller's responsibility to ensure that len > 0.
CountPackedVarint(const uint8 * buf,size_t len)374   Status CountPackedVarint(const uint8* buf, size_t len) {
375     const uint8* bound = buf + len;
376     int count;
377 
378     // The last byte in a valid encoded varint is guaranteed to have the high
379     // bit unset. We rely on this property to prevent ReadVarint64FromArray from
380     // going out of bounds, so validate the end of the buf before scanning
381     // anything.
382     if (bound[-1] & 0x80) {
383       return errors::DataLoss("Corrupt packed varint");
384     }
385 
386     // Now we can trust ReadVarint64FromArray to stay in bounds.
387     for (count = 0; buf < bound; ++count) {
388       uint64 temp;
389       bool ok;
390       buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
391       if (!ok) {
392         return errors::DataLoss("Corrupt packed varint");
393       }
394     }
395 
396     *count_ptr_ += count;
397     return Status::OK();
398   }
399 
400   // Counts the number of fixed-size values in a packed field. This can be done
401   // without actually parsing anything.
402   template <typename T>
CountPackedFixed(const uint8 * unused_buf,size_t len)403   Status CountPackedFixed(const uint8* unused_buf, size_t len) {
404     int count = len / sizeof(T);
405     if (count * sizeof(T) != len) {
406       return errors::DataLoss(
407           "Illegal data length for packed fixed-size type: ", len);
408     }
409     *count_ptr_ += len / sizeof(T);
410     return Status::OK();
411   }
412 
413   // Skips a single value in the input stream. Dispatches to the appropriately
414   // typed field skipper based on the schema type tag. This is not as permissive
415   // as just handling the wire type.
SkipValue(CodedInputStream * input,const FieldInfo & field)416   static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
417     uint32 tmp32;
418     protobuf_uint64 tmp64;
419     switch (field.type) {
420       case WireFormatLite::TYPE_DOUBLE:
421         return input->ReadLittleEndian64(&tmp64);
422       case WireFormatLite::TYPE_FLOAT:
423         return input->ReadLittleEndian32(&tmp32);
424       case WireFormatLite::TYPE_INT64:
425         return input->ReadVarint64(&tmp64);
426       case WireFormatLite::TYPE_UINT64:
427         return input->ReadVarint64(&tmp64);
428       case WireFormatLite::TYPE_INT32:
429         return input->ReadVarint32(&tmp32);
430       case WireFormatLite::TYPE_FIXED64:
431         return input->ReadLittleEndian64(&tmp64);
432       case WireFormatLite::TYPE_FIXED32:
433         return input->ReadLittleEndian32(&tmp32);
434       case WireFormatLite::TYPE_BOOL:
435         return input->ReadVarint32(&tmp32);
436       case WireFormatLite::TYPE_STRING:
437         return SkipBytes(input);
438       case WireFormatLite::TYPE_GROUP:
439         return WireFormatLite::SkipField(
440             input, WireFormatLite::MakeTag(
441                        field.number, WireFormatLite::WIRETYPE_START_GROUP));
442       case WireFormatLite::TYPE_MESSAGE:
443         return SkipBytes(input);
444       case WireFormatLite::TYPE_BYTES:
445         return SkipBytes(input);
446       case WireFormatLite::TYPE_UINT32:
447         return input->ReadVarint32(&tmp32);
448       case WireFormatLite::TYPE_ENUM:
449         return input->ReadVarint32(&tmp32);
450       case WireFormatLite::TYPE_SFIXED32:
451         return input->ReadLittleEndian32(&tmp32);
452       case WireFormatLite::TYPE_SFIXED64:
453         return input->ReadLittleEndian64(&tmp64);
454       case WireFormatLite::TYPE_SINT32:
455         return input->ReadVarint32(&tmp32);
456       case WireFormatLite::TYPE_SINT64:
457         return input->ReadVarint64(&tmp64);
458         // default: intentionally omitted in order to enable static checking.
459     }
460   }
461 
462   int32* count_ptr_ = nullptr;
463 };
464 
465 // A DenseCollector accumulates values from a proto into a tensor.
466 //
467 // There is an instance of DenseCollector for each field of each proto. The
468 // DenseCollector deserializes the value from the wire directly into the
469 // preallocated output Tensor.
470 //
471 // This class is named DenseCollector because in the future there should be a
472 // SparseCollector that accumulates field data into sparse tensors if the user
473 // requests it.
474 class DenseCollector {
475  public:
476   DenseCollector() = delete;
477 
478   // A DenseCollector applies to one field of a serialized message.
479   // Note that default_value.dtype is the type of the output tensor.
DenseCollector(uint8 * datap,DefaultValue default_value,int max_repeat_count)480   DenseCollector(uint8* datap, DefaultValue default_value, int max_repeat_count)
481       : datap_(datap),
482         default_value_(default_value),
483         max_repeat_count_(max_repeat_count) {}
484 
485   // Reads a value from the input stream and stores it.
486   //
487   // Always inlining gave a ~50% speedup on microbenchmarks at one point.
488   // TODO(nix): try removing it to see if that still holds.
489   // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
ReadValue(CodedInputStream * input,const FieldInfo & field)490   Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
491     // For required and optional fields, we overwrite values[0] with
492     // the latest one in the wire stream.
493     // See https://developers.google.com/protocol-buffers/docs/encoding#optional
494     // Only for repeated fields do we advance the next_repeat_index_ past 1.
495     // TODO(nix): to handle oneof we must also zero out any previous values
496     //  seen on the wire.
497     int32 index = 0;
498     if (field.is_repeated) {
499       index = next_repeat_index_;
500     }
501     next_repeat_index_ = index + 1;
502 
503     return internal::ReadValue(input, field.type, field.number,
504                                default_value_.dtype, index, datap_);
505   }
506 
507   // Reads and stores a length-delimited list of values.
ReadPackedValues(CodedInputStream * input,const FieldInfo & field,const size_t buf_size)508   Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
509                           const size_t buf_size) {
510     const void* buf;
511     int unused_max_buf_size;
512     input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
513     // This is safe because the underlying storage for the CodedInputStream is
514     // owned by the input tensor. If it were a Cord or file-backed stream this
515     // pointer would go stale after the bytes were skipped.
516     if (!input->Skip(buf_size)) {
517       return errors::DataLoss(
518           "ReadPackedValues: Skipping packed field failed.  Field tag: ",
519           field.number);
520     }
521 
522     // Setting stride=0 causes new values to overwrite old ones for
523     // non-repeated fields.
524     const int stride = field.is_repeated ? 1 : 0;
525 
526     if (next_repeat_index_ >= max_repeat_count_) {
527       return errors::DataLoss(
528           "ReadPackedValues: Tried to write more entries than allowed.  "
529           "Field tag: ",
530           field.number, ", Max entries allowed: ", max_repeat_count_);
531     } else {
532       return internal::ReadPackedFromArray(buf, buf_size, field.type,
533                                            field.number, default_value_.dtype,
534                                            stride, &next_repeat_index_, datap_);
535     }
536   }
537 
538   // Fills in any missing values in the output array with defaults. Dispatches
539   // to the appropriately typed field default based on the runtime type tag.
FillWithDefaults()540   Status FillWithDefaults() {
541     switch (default_value_.dtype) {
542       case DataType::DT_BOOL:
543         return FillDefault<bool>(default_value_.value.v_bool);
544       case DataType::DT_FLOAT:
545         return FillDefault<float>(default_value_.value.v_float);
546       case DataType::DT_DOUBLE:
547         return FillDefault<double>(default_value_.value.v_double);
548       case DataType::DT_INT8:
549         return FillDefault<int8>(default_value_.value.v_int8);
550       case DataType::DT_INT32:
551         return FillDefault<int32>(default_value_.value.v_int32);
552       case DataType::DT_INT64:
553         return FillDefault<int64>(default_value_.value.v_int64);
554       case DataType::DT_STRING:
555         return FillDefault<string>(default_value_.value.v_string);
556       case DataType::DT_UINT8:
557         return FillDefault<uint8>(default_value_.value.v_uint8);
558       case DataType::DT_UINT32:
559         return FillDefault<uint32>(default_value_.value.v_uint32);
560       case DataType::DT_UINT64:
561         return FillDefault<uint64>(default_value_.value.v_uint64);
562       default:
563         // There are many tensorflow dtypes not handled here, but they
564         // should not come up unless type casting is added to the Op.
565         // Chaining with tf.cast() should do the right thing until then.
566         return errors::DataLoss("Failed filling defaults for ",
567                                 DataTypeString(default_value_.dtype));
568     }
569   }
570 
571  private:
572   // Fills empty values in the dense representation with a default value. This
573   // uses next_repeat_index_ which counts the number of parsed values for the
574   // field.
575   template <class T>
FillDefault(const T & default_value)576   Status FillDefault(const T& default_value) {
577     for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
578       reinterpret_cast<T*>(datap_)[i] = default_value;
579     }
580     return Status::OK();
581   }
582 
583   int32 next_repeat_index_ = 0;
584 
585   // This is a pointer to data_[message_index_]. There is no bounds checking at
586   // this level: we computed the max repeat size for each field in
587   // CountCollector and use the same code to traverse it here, so we are
588   // guaranteed not to be called for more items than we have allocated space.
589   void* const datap_ = nullptr;
590 
591   const DefaultValue default_value_;
592   const int max_repeat_count_ = 0;
593 };
594 
595 class DecodeProtoOp : public OpKernel {
596  public:
DecodeProtoOp(OpKernelConstruction * context)597   explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
598     string descriptor_source;
599     OP_REQUIRES_OK(context,
600                    context->GetAttr("descriptor_source", &descriptor_source));
601 
602     // We always get back a desc_pool, but we may not own it. If we own it,
603     // owned_desc_pool_ will be filled in.
604     DescriptorPool const* desc_pool;
605     OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
606                                               &desc_pool, &owned_desc_pool_));
607 
608     string message_type;
609     OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
610 
611     const Descriptor* message_desc =
612         desc_pool->FindMessageTypeByName(message_type);
613     OP_REQUIRES(context, message_desc != nullptr,
614                 errors::InvalidArgument("No descriptor found for message type ",
615                                         message_type));
616 
617     std::vector<string> field_names;
618     OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
619     std::vector<DataType> output_types;
620     OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
621     OP_REQUIRES(
622         context, field_names.size() == output_types.size(),
623         errors::InvalidArgument("field_names and output_types attributes must "
624                                 "have the same length"));
625 
626     // Gather the field descriptors and check that requested output types match.
627     int field_index = 0;
628     std::vector<const FieldDescriptor*> field_descs;
629     std::vector<const FieldDescriptor*> exts;
630     absl::flat_hash_map<string, const FieldDescriptor*> ext_name_to_field;
631     std::vector<const FieldDescriptor*>::iterator ext_it = exts.begin();
632     for (const string& name : field_names) {
633       auto fd = message_desc->FindFieldByName(name);
634       if (fd == nullptr) {
635         // If field can't be found in original message, try to find a matching
636         // extension (by its full_name). First check a hashmap for a matching
637         // extension, and if not found, then iterate through available
638         // extensions to find a match (updating the hashmap while iterating.)
639         auto lookup_result = ext_name_to_field.find(name);
640         if (lookup_result != ext_name_to_field.end()) {
641           fd = lookup_result->second;
642         } else {
643           if (ext_it == exts.begin()) {
644             desc_pool->FindAllExtensions(message_desc, &exts);
645             ext_it = exts.begin();
646           }
647           while (ext_it != exts.end()) {
648             auto ext_name = (*ext_it)->full_name();
649             auto ext_field = *ext_it;
650             ++ext_it;
651 
652             ext_name_to_field.insert({ext_name, ext_field});
653             if (ext_name == name) {
654               fd = ext_field;
655               break;
656             }
657           }
658         }
659       }
660       OP_REQUIRES(context, fd != nullptr,
661                   errors::InvalidArgument("Unknown field: ", name,
662                                           " in message type ", message_type));
663       OP_REQUIRES(
664           context,
665           proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
666           // Many TensorFlow types don't have corresponding proto types and the
667           // user will get an error if they are requested. It would be nice to
668           // allow conversions here, but tf.cast already exists so we don't
669           // duplicate the functionality.
670           errors::InvalidArgument("Unexpected output type for ",
671                                   fd->full_name(), ": ", fd->cpp_type(), " to ",
672                                   output_types[field_index]));
673 
674       field_index++;
675       field_descs.push_back(fd);
676     }
677 
678     // Internally we want the field_descs sorted by their number on the wire.
679     // But the output tensors are allocated in the order given by the caller.
680     // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
681     std::vector<int> output_indices;
682     output_indices.reserve(field_names.size());
683     for (int i = 0; i < field_names.size(); i++) {
684       output_indices.push_back(i);
685     }
686     std::sort(output_indices.begin(), output_indices.end(),
687               [field_descs](int a, int b) {
688                 return field_descs[a]->number() < field_descs[b]->number();
689               });
690 
691     // Now store the fields in sorted order.
692     for (int i = 0; i < field_names.size(); i++) {
693       const int output_index = output_indices[i];
694       const DataType dtype = output_types[output_index];
695       const FieldDescriptor* field_descriptor = field_descs[output_index];
696       DefaultValue default_value;
697       OP_REQUIRES_OK(context, InitDefaultValueFromFieldDescriptor(
698                                   dtype, field_descriptor, &default_value));
699       fields_.push_back(
700           MakeUnique<FieldInfo>(field_descriptor, output_index, default_value));
701     }
702 
703     message_prototype_ = message_factory_.GetPrototype(message_desc);
704     OP_REQUIRES(context, message_prototype_ != nullptr,
705                 errors::InvalidArgument("Couldn't get prototype message: ",
706                                         message_desc->full_name()));
707     string format;
708     OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
709     OP_REQUIRES(
710         context, format == "binary" || format == "text",
711         errors::InvalidArgument("format must be one of binary or text"));
712     is_binary_ = format == "binary";
713 
714     // Enable the initial protobuf sanitizer, which is much more expensive than
715     // the decoder.
716     // TODO(nix): Remove this once the fast decoder has passed security review.
717     OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
718   }
719 
Compute(OpKernelContext * ctx)720   void Compute(OpKernelContext* ctx) override {
721     const Tensor& buf_tensor = ctx->input(0);
722     int message_count = buf_tensor.NumElements();
723     OP_REQUIRES(ctx, message_count >= 1,
724                 errors::InvalidArgument(
725                     "Bufs argument must contain at least one value"));
726 
727     int field_count = fields_.size();
728 
729     // Save the argument shape for later, then flatten the input Tensor since we
730     // are working componentwise. We will restore the same shape in the returned
731     // Tensor.
732     const TensorShape& shape_prefix = buf_tensor.shape();
733 
734     TensorShape sizes_shape = shape_prefix;
735     sizes_shape.AddDim(field_count);
736     Tensor* sizes_tensor = nullptr;
737     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
738 
739     // This is used to allocate binary bufs if used. It serves only to define
740     // memory ownership.
741     std::vector<string> tmp_binary_bufs(message_count);
742 
743     // These are the actual buffers to use, which may be in tmp_binary_bufs
744     // or may be pointers into the buf_tensor. Either way they are not owned
745     // here.
746     std::vector<const string*> bufs;
747 
748     if (is_binary_ && !sanitize_) {
749       // Fast path.
750       for (int mi = 0; mi < message_count; ++mi) {
751         const string* buf = &buf_tensor.flat<string>()(mi);
752         bufs.push_back(buf);
753       }
754     } else {
755       // We will have to allocate a copy, either to convert from text to binary
756       // or to sanitize a binary proto.
757       for (int mi = 0; mi < message_count; ++mi) {
758         ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
759                            &tmp_binary_bufs[mi]);
760         if (!ctx->status().ok()) {
761           return;
762         }
763         bufs.push_back(&tmp_binary_bufs[mi]);
764       }
765     }
766 
767     // Walk through all the strings in the input tensor, counting the number of
768     // fields in each. We can't allocate our actual output Tensor until we know
769     // the maximum repeat count, so we do a first pass through the serialized
770     // proto just counting fields. We always allocate at least one value so that
771     // optional fields are populated with default values - this avoids a TF
772     // conditional when handling the output data. The caller can distinguish
773     // between real data and defaults using the repeat count matrix that is
774     // returned by decode_proto.
775     std::vector<int32> max_sizes(field_count, 1);
776     for (int mi = 0; mi < message_count; ++mi) {
777       CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
778       if (!ctx->status().ok()) {
779         return;
780       }
781     }
782 
783     // Allocate the output tensors now that we've seen the max size.
784     // TODO(nix): Use allocate_output_or_forward_input for the largest
785     //   output tensor. This can avoid one large allocation by re-using
786     //   the memory of the input tensor.
787     std::vector<Tensor*> outputs(field_count);
788     for (int fi = 0; fi < field_count; ++fi) {
789       TensorShape flat_shape = {static_cast<int64>(message_count),
790                                 max_sizes[fi]};
791       TensorShape out_shape = shape_prefix;
792       out_shape.AddDim(max_sizes[fi]);
793 
794       // Surprisingly we don't specify the types from the output_types
795       // attribute: that is done for us based on the Op declaration:
796       //  REGISTER_OP(...)
797       //    .Attr("output_types: list(type) >= 0")
798       //    .Output("values: output_types")
799       OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
800                                                out_shape, &outputs[fi]));
801     }
802 
803     // Make the second pass through the serialized proto, decoding into
804     // preallocated tensors.
805     AccumulateFields(ctx, bufs, outputs);
806   }
807 
808  private:
809   // Copy a serialized message to binary, e.g. to handle text proto inputs.
ReserializeMessage(OpKernelContext * ctx,const string & buf,string * binary_buf)810   void ReserializeMessage(OpKernelContext* ctx, const string& buf,
811                           string* binary_buf) {
812     // Handle text protos by translating them to binary.
813     std::unique_ptr<Message> message(message_prototype_->New());
814     OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed"));
815 
816     if (is_binary_) {
817       // If we get here we are sanitizing the input protobuf by parsing
818       // and reserializing it with a trusted (but very slow) library.
819       OP_REQUIRES(ctx, message->ParseFromString(buf),
820                   errors::DataLoss("Unable to parse binary protobuf"));
821     } else {
822       OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()),
823                   errors::DataLoss("Unable to parse text protobuf"));
824     }
825 
826     OP_REQUIRES(ctx, message->SerializeToString(binary_buf),
827                 errors::DataLoss("Unable to reserialize text proto as binary"));
828   }
829 
830   // Count the number of occurrences of each requested field in a message batch.
CountFields(OpKernelContext * ctx,int message_index,const string & buf,Tensor * sizes_tensor,std::vector<int32> * max_sizes)831   void CountFields(OpKernelContext* ctx, int message_index, const string& buf,
832                    Tensor* sizes_tensor, std::vector<int32>* max_sizes) {
833     int field_count = fields_.size();
834 
835     CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
836                            buf.size());
837 
838     std::vector<int32> field_sizes(field_count, 0);
839     std::vector<CountCollector> counters;
840     counters.reserve(field_count);
841     for (int i = 0; i < field_count; i++) {
842       counters.emplace_back(&field_sizes[i]);
843     }
844 
845     Status st = Collect(&input, &counters);
846     if (st.ok() && !input.ConsumedEntireMessage()) {
847       st = errors::DataLoss("CountFields: Failed to consume entire buffer");
848     }
849     if (kFailOnDecodeError) {
850       OP_REQUIRES_OK(ctx, st);  // NOLINT
851     }
852     if (!st.ok()) {
853       // This code suppresses the corrupt proto, treating it as empty
854       // to avoid crashing the process.
855       LOG(WARNING) << "Proto counting error for message type " << message_type_
856                    << ": " << st;
857 
858       for (int fi = 0; fi < field_count; fi++) {
859         field_sizes[fi] = 0;
860       }
861       // Finished decoding this message.
862       return;
863     }
864 
865     // Update the size tensor and max repeat size for each field.
866     auto sizes = sizes_tensor->flat_inner_dims<int32>();
867     for (int fi = 0; fi < field_count; fi++) {
868       int32 size = field_sizes[fi];
869       sizes(message_index, fields_[fi]->output_index) = size;
870       if ((*max_sizes)[fi] < size) {
871         (*max_sizes)[fi] = size;
872       }
873     }
874   }
875 
876   // Parse fields from a serialized message into preallocated tensors.
AccumulateFields(OpKernelContext * ctx,const std::vector<const string * > & bufs,std::vector<Tensor * > outputs)877   void AccumulateFields(OpKernelContext* ctx,
878                         const std::vector<const string*>& bufs,
879                         std::vector<Tensor*> outputs) {
880     struct TensorInfo {
881       explicit TensorInfo(Tensor* tensor) {
882         // Note that we can decode only max_repeat_count values before overflow.
883         // No other bounds checking is done for repeated fields. For
884         // optional fields there is a check to make sure that only the last
885         // value on the wire appears in the output tensor.
886         dtype = tensor->dtype();
887         last_dim_size = tensor->dim_size(tensor->dims() - 1);
888 
889         if (dtype != DT_STRING) {
890           const int element_size = DataTypeSize(dtype);
891           CHECK_GT(element_size, 0);
892           stride = last_dim_size * element_size;
893 
894           const int64 flatshape[1] = {tensor->NumElements() * element_size};
895           data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
896         } else {
897           // DataTypeSize() returns 0 for string types.
898           stride = last_dim_size * sizeof(string);
899           data = reinterpret_cast<uint8*>(tensor->flat<string>().data());
900         }
901       }
902 
903       DataType dtype;
904       int last_dim_size;
905       int stride;
906       uint8* data;
907     };
908 
909     int field_count = fields_.size();
910 
911     std::vector<TensorInfo> tensors;
912     tensors.reserve(field_count);
913     for (int fi = 0; fi < field_count; fi++) {
914       tensors.emplace_back(outputs[fi]);
915     }
916 
917     for (int message_index = 0; message_index < bufs.size(); ++message_index) {
918       const string& buf = *bufs[message_index];
919 
920       std::vector<DenseCollector> collectors;
921       collectors.reserve(field_count);
922       for (int output_index = 0; output_index < field_count; ++output_index) {
923         const TensorInfo& info = tensors[output_index];
924         const FieldInfo* field_info = fields_[output_index].get();
925         DCHECK(field_info != nullptr);
926         const DefaultValue default_value = field_info->default_value;
927         collectors.emplace_back(info.data + message_index * info.stride,
928                                 default_value, info.last_dim_size);
929       }
930 
931       // Fill in output tensors from the wire.
932       CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
933                              buf.size());
934       Status st = Collect(&input, &collectors);
935       if (st.ok() && !input.ConsumedEntireMessage()) {
936         st = errors::DataLoss(
937             "AccumulateFields: Failed to consume entire buffer");
938       }
939       if (kFailOnDecodeError) {
940         OP_REQUIRES_OK(ctx, st);  // NOLINT
941       }
942       if (!st.ok()) {
943         // This code suppresses the corrupt proto, treating it as empty
944         // to avoid crashing training.
945         LOG(WARNING) << "Proto counting error for message type "
946                      << message_type_ << ": " << st;
947       }
948 
949       // Fill the remainder of the dense outputs with default values.
950       for (auto& collector : collectors) {
951         OP_REQUIRES_OK(ctx, collector.FillWithDefaults());
952       }
953     }
954   }
955 
956   // Look up the FieldDescriptor for a particular field number.
LookupField(int field_number,int * field_index)957   bool LookupField(int field_number, int* field_index) {
958     // Look up the FieldDescriptor using linear search.
959     //
960     // TODO(nix): this could be sped up with binary search, but we are
961     // already way off the fastpath at this point. If you see a hotspot
962     // here, somebody is sending you very inefficient protos.
963     for (int fi = fields_.size() - 1; fi >= 0; fi--) {
964       if (field_number == fields_[fi]->number) {
965         *field_index = fi;
966         return true;
967       }
968     }
969     return false;
970   }
971 
972   // Traverses a serialized protobuf, dispatching values to the collectors.
973   template <class CollectorClass>
Collect(CodedInputStream * input,std::vector<CollectorClass> * collectors)974   Status Collect(CodedInputStream* input,
975                  std::vector<CollectorClass>* collectors) {
976     int last_good_field_index = -1;
977     bool fields_disordered = false;
978     int prev_field_number = -1;
979     int field_number = -1;
980     int last_good_field_number = -1;
981     int next_good_field_number = fields_[0]->number;
982 
983     // The 'tag' variable should always be treated as tainted.
984     for (uint32 tag = input->ReadTag();
985          tag != 0 && WireFormatLite::GetTagWireType(tag) !=
986                          WireFormatLite::WIRETYPE_END_GROUP;
987          tag = input->ReadTag(), prev_field_number = field_number) {
988       field_number = WireFormatLite::GetTagFieldNumber(tag);
989       const FieldInfo* field = nullptr;
990 
991       // This takes advantage of the sorted field numbers in most serialized
992       // protos: it tries the next expected field first rather than doing
993       // a lookup by field number.
994       //
995       // TODO(nix): haberman@ suggests a hybrid approach with a lookup table
996       // for small field numbers and a hash table for larger ones. This would
997       // be a simpler approach that should offer comparable speed in most
998       // cases.
999       if (field_number == last_good_field_number) {
1000         field = fields_[last_good_field_index].get();
1001       } else {
1002         if (field_number < prev_field_number) {
1003           fields_disordered = true;
1004         }
1005 
1006         // If fields are out of order, fall back to slow lookup.
1007         if (fields_disordered) {
1008           int field_index;
1009           if (LookupField(field_number, &field_index)) {
1010             field = fields_[field_index].get();
1011             last_good_field_index = field_index;
1012           }
1013         } else {
1014           // If we see a field that is past the next field we want, it was
1015           // empty. Look for the one after that. Repeat until we run out of
1016           // fields that we care about.
1017           while (field_number >= next_good_field_number) {
1018             if (field_number == next_good_field_number) {
1019               last_good_field_number = field_number;
1020               field = fields_[last_good_field_index + 1].get();
1021             }
1022 
1023             // Start looking for the field after the current one.
1024             ++last_good_field_index;
1025             if (last_good_field_index < fields_.size() - 1) {
1026               next_good_field_number =
1027                   fields_[last_good_field_index + 1]->number;
1028             } else {
1029               // Saw something past the last field we care about. Continue
1030               // parsing the message just in case there are disordered fields
1031               // later, but any remaining ordered fields will have no effect.
1032               next_good_field_number = INT_MAX;
1033             }
1034           }
1035         }
1036       }
1037 
1038       if (!field) {
1039         // Unknown and unrequested fields are skipped.
1040         if (!WireFormatLite::SkipField(input, tag)) {
1041           return errors::DataLoss("Failed skipping unrequested field");
1042         }
1043         continue;
1044       }
1045 
1046       Status st = CollectField(*field, WireFormatLite::GetTagWireType(tag),
1047                                input, &(*collectors)[last_good_field_index]);
1048       if (!st.ok()) {
1049         return st;
1050       }
1051     }
1052     return Status::OK();
1053   }
1054 
1055   // Collects values for a single field.
1056   template <class CollectorClass>
CollectField(const FieldInfo & field,WireFormatLite::WireType wire_type,CodedInputStream * input,CollectorClass * collector)1057   Status CollectField(const FieldInfo& field,
1058                       WireFormatLite::WireType wire_type,
1059                       CodedInputStream* input, CollectorClass* collector) {
1060     // The wire format library defines the same constants used in
1061     // descriptor.proto. This static_cast is safe because they are guaranteed to
1062     // stay in sync.
1063     //
1064     // We need the field type from the FieldDescriptor here because the wire
1065     // format doesn't tell us anything about what happens inside a packed
1066     // repeated field: there is enough information in the wire format to skip
1067     // the whole field but not enough to know how to parse what's inside. For
1068     // that we go to the schema.
1069     WireFormatLite::WireType schema_wire_type =
1070         WireFormatLite::WireTypeForFieldType(field.type);
1071 
1072     // Handle packed repeated fields. SkipField would skip the whole
1073     // length-delimited blob without letting us count the values, so we have to
1074     // scan them ourselves.
1075     if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
1076         schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
1077       // Handle packed repeated primitives.
1078       int length;
1079       if (!input->ReadVarintSizeAsInt(&length)) {
1080         return errors::DataLoss("CollectField: Failed reading packed size");
1081       }
1082       return collector->ReadPackedValues(input, field, length);
1083     }
1084 
1085     // Read ordinary values, including strings, bytes, and messages.
1086     if (wire_type != schema_wire_type) {
1087       if (!WireFormatLite::SkipField(
1088               input, WireFormatLite::MakeTag(field.number, wire_type))) {
1089         return errors::DataLoss(
1090             "CollectField: Failed skipping malformed field");
1091       }
1092       return Status::OK();
1093     }
1094     return collector->ReadValue(input, field);
1095   }
1096 
  // Fully qualified name of the proto message type being decoded.
  string message_type_;
  // Note that fields are sorted by increasing field number, which is not in
  // general the order given by the user-specified field_names and output_types
  // Op attributes. Each FieldInfo carries the caller-facing output index.
  std::vector<std::unique_ptr<const FieldInfo>> fields_;

  // Owned_desc_pool_ is null when using descriptor_source=local.
  std::unique_ptr<DescriptorPool> owned_desc_pool_;
  DynamicMessageFactory message_factory_;
  // Borrowed prototype obtained from message_factory_.GetPrototype(); used to
  // create scratch messages in ReserializeMessage. Raw pointer, not owned here.
  const Message* message_prototype_;

  // True if decoding binary format, false if decoding text format.
  bool is_binary_;

  // True if the protos should be sanitized before parsing. Enables the initial
  // protobuf sanitizer, which is much more expensive than the decoder. The flag
  // defaults to true but can be set to false for trusted sources.
  //
  // TODO(nix): Flip the default to false when the fast decoder has passed
  // security review.
  bool sanitize_;

  TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
};
1121 
// Register the kernel: DecodeProtoV2 runs on CPU only.
REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU),
                        DecodeProtoOp);
1124 
1125 }  // namespace
1126 }  // namespace tensorflow
1127