• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/parse_example/parse_example.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 
24 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
25 #include "tensorflow/core/example/feature.pb.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/platform/errors.h"
31 #include "tensorflow/core/platform/fingerprint.h"
32 #include "tensorflow/core/public/session_options.h"
33 #include "tensorflow/core/util/example_proto_fast_parsing.h"
34 #include "tensorflow/core/util/presized_cuckoo_map.h"
35 #include "tensorflow/lite/c/common.h"
36 #include "tensorflow/lite/kernels/internal/tensor.h"
37 #include "tensorflow/lite/kernels/kernel_util.h"
38 #include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"
39 #include "tensorflow/lite/mutable_op_resolver.h"
40 #include "tensorflow/lite/string_util.h"
41 
42 namespace tflite {
43 namespace ops {
44 namespace custom {
45 namespace parse_example {
46 namespace {
47 
48 namespace tf = ::tensorflow;
49 using tf::Status;
50 using tf::StringPiece;
51 using tf::tstring;
52 using tf::example::CopyOrMoveBlock;
53 using tf::example::FastParseExampleConfig;
54 using tf::example::GetListFromBuffer;
55 using tf::example::LimitedArraySlice;
56 using tf::example::ParseExample;
57 using tf::example::SeededHasher;
58 using tf::example::SmallVector;
59 using tf::example::SparseBuffer;
60 using tf::example::Type;
61 using tf::example::parsed::Example;
62 
63 using ConfigIndex = tf::PresizedCuckooMap<std::pair<int32_t, Type>>;
64 
65 struct TfLiteResult {
66   std::vector<TfLiteTensor*> dense_values;
67   std::vector<TfLiteTensor*> sparse_values;
68   std::vector<TfLiteTensor*> sparse_indices;
69   std::vector<TfLiteTensor*> sparse_shapes;
70   std::map<int, tf::Tensor> dense_tensors;
71 };
72 
73 template <typename T>
FillAndCopyVarLen(const int d,const size_t num_elements,const size_t num_elements_per_minibatch,const FastParseExampleConfig & config,std::vector<SparseBuffer> & varlen_dense_buffers,TfLiteTensor * values)74 void FillAndCopyVarLen(const int d, const size_t num_elements,
75                        const size_t num_elements_per_minibatch,
76                        const FastParseExampleConfig& config,
77                        std::vector<SparseBuffer>& varlen_dense_buffers,
78                        TfLiteTensor* values) {
79   const tf::Tensor& default_value = config.dense[d].default_value;
80 
81   // Copy-fill the tensors (creating the zero/fill-padding)
82   std::fill(reinterpret_cast<T*>(values->data.raw),
83             reinterpret_cast<T*>(values->data.raw) + num_elements,
84             default_value.flat<T>()(0));
85 
86   auto data = reinterpret_cast<T*>(values->data.raw);
87 
88   const SparseBuffer& buffer = varlen_dense_buffers[d];
89   // Number of examples being stored in this buffer
90   const auto& end_indices = buffer.example_end_indices;
91   const size_t examples_in_buffer = end_indices.size();
92 
93   const auto& list = GetListFromBuffer<T>(buffer);
94   auto list_ptr = list.begin();
95 
96   size_t elements_tally = 0;
97   // Iterate through all the examples stored in this buffer.
98   for (size_t j = 0; j < examples_in_buffer; ++j) {
99     // Number of elements stored for this example.
100     const size_t num_elems = end_indices[j] - elements_tally;
101     CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
102     // Move forward this many elements in the varlen buffer.
103     list_ptr += num_elems;
104     // Move forward to the next minibatch entry in the values output.
105     data += num_elements_per_minibatch;
106     elements_tally = end_indices[j];
107   }
108   DCHECK(elements_tally == list.size());
109 }
110 
ParseExample(StringRef serialized,Example * example)111 bool ParseExample(StringRef serialized, Example* example) {
112   DCHECK(example != nullptr);
113   tf::protobuf::io::CodedInputStream stream(
114       reinterpret_cast<const uint8*>(serialized.str), serialized.len);
115   tensorflow::example::EnableAliasing(&stream);
116   return ParseExample(&stream, example);
117 }
118 
// Parses a single serialized tf.Example and scatters its features into the
// preallocated outputs:
//   - fixed-stride dense features are written directly into `output_dense`
//     (strings go through `result->dense_tensors` instead, since TFLite
//     string tensors are packed later);
//   - variable-length dense features accumulate into `output_varlen_dense`;
//   - sparse features accumulate into `output_sparse`.
// Features absent from the example are backfilled from the configured default
// value (fixed dense) or recorded as empty (varlen dense / sparse).
// `quick_filter[len]` pre-screens feature names by length before the hashed
// `config_index` lookup.  Returns an Internal error on any parse/type/shape
// mismatch.  NOTE(review): `stats`, `quick_filter_size` as passed here, and
// `config_index_size` are unused in this body — presumably kept for parity
// with the TF kernel; confirm before removing.
Status FastParseSerializedExample(
    StringRef serialized_example, const tstring& example_name,
    const size_t example_index, const FastParseExampleConfig& config,
    bool* quick_filter, int quick_filter_size,
    const std::unique_ptr<ConfigIndex>& config_index, int config_index_size,
    SeededHasher* hasher, std::vector<TfLiteTensor*>* output_dense,
    std::vector<SparseBuffer>* output_varlen_dense,
    std::vector<SparseBuffer>* output_sparse,
    std::map<absl::string_view, int>& stats, TfLiteResult* result) {
  DCHECK(output_dense != nullptr);
  tensorflow::example::parsed::Example parsed_example;
  if (!ParseExample(serialized_example, &parsed_example)) {
    return tf::errors::Internal("Failed to parse example");
  }
  // Track, per feature, the last example index that provided a value, so we
  // can (a) skip duplicate keys and (b) backfill missing features below.
  std::vector<int64_t> dense_feature_last_example(config.dense.size(), -1);
  std::vector<int64_t> sparse_feature_last_example(config.sparse.size(), -1);
  // Handle features present in the example.
  const size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    // (Hence the reverse iteration order here.)
    tensorflow::example::parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    const StringPiece feature_name = name_and_feature.first;
    tensorflow::example::parsed::Feature& feature = name_and_feature.second;
    // Cheap length-based filter before the hash lookup: no configured feature
    // has this name length -> skip.
    if (feature_name.length() >= quick_filter_size ||
        !quick_filter[feature_name.length()]) {
      continue;
    }
    // Resolve the feature name to (feature index, Dense/Sparse) via the
    // fingerprint-keyed cuckoo map.
    const uint64_t h = (*hasher)(feature_name);
    std::pair<int32_t, Type> d_and_type;
    if (!config_index->Find(h, &d_and_type)) {
      continue;
    }
    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;

    auto example_error = [&](StringPiece suffix) {
      return tf::errors::Internal("Name: ", example_name,
                                  ", Key: ", feature_name,
                                  ", Index: ", example_index, ".  ", suffix);
    };

    auto parse_error = [&] {
      return example_error("Can't parse serialized Example.");
    };

    tf::DataType example_dtype;
    if (feature.ParseDataType(&example_dtype) != ::tensorflow::OkStatus()) {
      return parse_error();
    }
    if (is_dense) {
      if (example_dtype == tf::DT_INVALID) continue;

      // If feature was already visited, skip.
      // Compare comment at the beginning of the loop.
      dense_feature_last_example[d] = example_index;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(absl::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }
      if (!config.dense[d].variable_length) {
        // Fixed-stride dense: parse straight into the output tensor at this
        // example's offset, and require exactly elements_per_stride values.
        TfLiteTensor* out = (*output_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;
        const std::size_t offset = example_index * num_elements;

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(absl::StrCat(
              "Number of ", type_str,
              " values != expected.  "
              "Values size:",
              size,
              " but output shape: ", config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case tf::DT_INT64: {
            auto out_p = reinterpret_cast<int64_t*>(out->data.raw) + offset;
            LimitedArraySlice<int64_t> slice(out_p, num_elements);
            if (!feature.ParseInt64List(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "int64");
            }
            break;
          }
          case tf::DT_FLOAT: {
            auto out_p = reinterpret_cast<float*>(out->data.raw) + offset;
            LimitedArraySlice<float> slice(out_p, num_elements);
            if (!feature.ParseFloatList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "float");
            }
            break;
          }
          case tf::DT_STRING: {
            // Strings are staged in a tf::Tensor; they are packed into the
            // TFLite string tensor later by the caller.
            auto& out_tensor = result->dense_tensors[d];
            auto out_p = out_tensor.flat<tstring>().data() + offset;
            LimitedArraySlice<tstring> slice(out_p, num_elements);
            if (!feature.ParseBytesList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "bytes");
            }
            break;
          }
          default:
            return tf::errors::Internal("Unrecognized dense type: ",
                                        config.dense[d].dtype);
        }
      } else {  // if dense variable length
        // Variable-length dense: append to the side buffer; value count must
        // be a whole multiple of the stride.
        SparseBuffer& out = (*output_varlen_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;

        if (example_dtype != tf::DT_INVALID &&
            example_dtype != config.dense[d].dtype) {
          return example_error(absl::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(config.dense[d].dtype)));
        }

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(
              absl::StrCat("Number of ", type_str,
                           " values is not a multiple of stride length. Saw ",
                           size, " values but output shape is: ",
                           config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case tf::DT_INT64: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseInt64List(&out.int64_list)) {
                return parse_error();
              }
              if (out.int64_list.size() % num_elements != 0) {
                return shape_error(out.int64_list.size(), "int64");
              }
            }
            out.example_end_indices.push_back(out.int64_list.size());
            break;
          }
          case tf::DT_FLOAT: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseFloatList(&out.float_list)) {
                return parse_error();
              }
              if (out.float_list.size() % num_elements != 0) {
                return shape_error(out.float_list.size(), "float");
              }
            }
            out.example_end_indices.push_back(out.float_list.size());
            break;
          }
          case tf::DT_STRING: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseBytesList(&out.bytes_list)) {
                return parse_error();
              }
              if (out.bytes_list.size() % num_elements != 0) {
                return shape_error(out.bytes_list.size(), "byte");
              }
            }
            out.example_end_indices.push_back(out.bytes_list.size());
            break;
          }
          default:
            return tf::errors::Internal("Should not happen: ",
                                        config.dense[d].dtype);
        }
      }
    } else {
      // is sparse or ragged
      auto& last_example = sparse_feature_last_example;
      // Duplicate key within this example: the reverse iteration above means
      // the first hit is the authoritative (last-in-proto) one, so skip.
      if (last_example[d] == example_index) {
        continue;
      }
      last_example[d] = example_index;
      SparseBuffer& out = (*output_sparse)[d];
      tf::DataType feature_dtype = config.sparse[d].dtype;
      if (example_dtype != tf::DT_INVALID && example_dtype != feature_dtype) {
        return tf::errors::Internal("Data types don't match:", example_dtype,
                                    " != ", feature_dtype);
      }
      switch (feature_dtype) {
        case tf::DT_INT64: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseInt64List(&out.int64_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.int64_list.size());
          break;
        }
        case tf::DT_FLOAT: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseFloatList(&out.float_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.float_list.size());
          break;
        }
        case tf::DT_STRING: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseBytesList(&out.bytes_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.bytes_list.size());
          break;
        }
        default:
          return tf::errors::Internal("Should not happen: ", feature_dtype);
      }
    }
  }
  // Handle missing dense features for fixed strides: copy the configured
  // default value into this example's slot, or fail if none is configured.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    if (config.dense[d].default_value.NumElements() == 0) {
      return tf::errors::Internal(
          "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
          " (data type: ", DataTypeString(config.dense[d].dtype), ")",
          " is required but could not be found.");
    }
    const tf::Tensor& in = config.dense[d].default_value;
    TfLiteTensor* out = result->dense_values[d];
    const std::size_t num_elements = in.shape().num_elements();
    const std::size_t offset = example_index * num_elements;
    switch (config.dense[d].dtype) {
      case tf::DT_INT64: {
        std::copy_n(in.flat<int64_t>().data(), num_elements,
                    out->data.i64 + offset);
        break;
      }
      case tf::DT_FLOAT: {
        std::copy_n(in.flat<float>().data(), num_elements,
                    out->data.f + offset);
        break;
      }
      case tf::DT_STRING: {
        auto& out_tensor = result->dense_tensors[d];
        std::copy_n(in.flat<tstring>().data(), num_elements,
                    out_tensor.flat<tstring>().data() + offset);
        break;
      }
      default:
        return tf::errors::Internal("Should not happen: ",
                                    config.dense[d].dtype);
    }
  }
  // Missing varlen dense features contribute zero elements: repeat the
  // previous cumulative end index so this example's run is empty.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_varlen_dense)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Likewise for missing sparse features.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (sparse_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_sparse)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  return ::tensorflow::OkStatus();
}
392 
CountSparseFeatures(const SparseBuffer & sparse_buffer,size_t * total_num_features,size_t * max_num_features)393 void CountSparseFeatures(const SparseBuffer& sparse_buffer,
394                          size_t* total_num_features, size_t* max_num_features) {
395   const std::vector<size_t>& end_indices = sparse_buffer.example_end_indices;
396   *total_num_features += end_indices.back();
397   *max_num_features = std::max(*max_num_features, end_indices[0]);
398   for (size_t i = 1; i < end_indices.size(); ++i) {
399     size_t example_size = end_indices[i] - end_indices[i - 1];
400     *max_num_features = std::max(*max_num_features, example_size);
401   }
402 }
403 
CopySparseBufferToTensor(tf::DataType dtype,size_t offset,SparseBuffer * src,TfLiteTensor * dst)404 void CopySparseBufferToTensor(tf::DataType dtype, size_t offset,
405                               SparseBuffer* src, TfLiteTensor* dst) {
406   switch (dtype) {
407     case tf::DT_INT64: {
408       std::copy(src->int64_list.begin(), src->int64_list.end(),
409                 reinterpret_cast<int64_t*>(dst->data.raw) + offset);
410       break;
411     }
412     case tf::DT_FLOAT: {
413       std::copy(src->float_list.begin(), src->float_list.end(),
414                 reinterpret_cast<float*>(dst->data.raw) + offset);
415       break;
416     }
417     case tf::DT_STRING: {
418       DynamicBuffer buffer;
419       for (auto* begin = src->bytes_list.begin();
420            begin != src->bytes_list.end(); begin++) {
421         buffer.AddString(begin->c_str(), begin->size());
422       }
423       buffer.WriteToTensor(dst, nullptr);
424       break;
425     }
426     default:
427       DCHECK(false) << "Encountered unexpected DataType "
428                     << DataTypeString(dtype)
429                     << "in variable that should have been checked.";
430   }
431 }
432 
CopyToBuffer(tf::gtl::ArraySlice<tstring> vec,char * tensor_buffer,int num_examples,int batch_size,int elements_per_stride)433 inline void CopyToBuffer(tf::gtl::ArraySlice<tstring> vec, char* tensor_buffer,
434                          int num_examples, int batch_size,
435                          int elements_per_stride) {
436   int i = 0, k = 0;
437   int start = 0;
438   for (; i < num_examples; ++i) {
439     for (int j = 0; j < elements_per_stride; ++j) {
440       memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
441       start += vec[k].size();
442       k++;
443     }
444   }
445   // Will happen if the number of examples is less than the desired batch size.
446   for (; i < batch_size; ++i) {
447     for (int j = 0; j < elements_per_stride; ++j) {
448       memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
449       start += vec[k].size();
450       k++;
451     }
452   }
453 }
454 
FastParseExampleLite(const FastParseExampleConfig & config,const TfLiteTensor * serialized,tf::gtl::ArraySlice<tstring> example_names,bool * quick_filter,int quick_filter_size,const std::unique_ptr<ConfigIndex> & config_index,int config_index_size,SeededHasher * hasher,TfLiteResult * result,std::map<absl::string_view,int> & stats,TfLiteContext * context)455 Status FastParseExampleLite(
456     const FastParseExampleConfig& config, const TfLiteTensor* serialized,
457     tf::gtl::ArraySlice<tstring> example_names, bool* quick_filter,
458     int quick_filter_size, const std::unique_ptr<ConfigIndex>& config_index,
459     int config_index_size, SeededHasher* hasher, TfLiteResult* result,
460     std::map<absl::string_view, int>& stats, TfLiteContext* context) {
461   if (result == nullptr) {
462     return tf::errors::Internal("Result is null");
463   }
464   const int count = GetStringCount(serialized);
465   std::vector<tf::Tensor> fixed_dense_values(config.dense.size());
466   std::vector<SparseBuffer> sparse_buffers(config.sparse.size());
467   std::vector<SparseBuffer> varlen_dense_buffers(config.dense.size());
468   Status status_of_minibatch;
469   for (size_t e = 0; e < count; ++e) {
470     Status status_of_minibatch = FastParseSerializedExample(
471         GetString(serialized, e),
472         (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
473         quick_filter, quick_filter_size, config_index, config_index_size,
474         hasher, &result->dense_values, &varlen_dense_buffers, &sparse_buffers,
475         /*arena,*/ stats, result);
476     if (!status_of_minibatch.ok()) break;
477   }
478   if (!status_of_minibatch.ok()) {
479     return status_of_minibatch;
480   }
481   // Merge SparseBuffers from all minibatches for every config.sparse.
482   // auto MergeSparseMinibatches = [&](size_t d) {
483   // Loop over minibatches
484   for (size_t d = 0; d < config.sparse.size(); ++d) {
485     size_t total_num_features = 0;
486     size_t max_num_features = 0;
487     CountSparseFeatures(sparse_buffers[d], &total_num_features,
488                         &max_num_features);
489     tf::TensorShape indices_shape;
490     TfLiteTensor* indices = result->sparse_indices[d];
491     TfLiteTensor* values = result->sparse_values[d];
492 
493     TfLiteTensor* dense_shape = result->sparse_shapes[d];
494     auto* dense_shape_ptr = reinterpret_cast<int64_t*>(dense_shape->data.raw);
495     dense_shape_ptr[1] = max_num_features;
496 
497     TfLiteIntArray* index_shape = TfLiteIntArrayCreate(2);
498     index_shape->data[0] = total_num_features;
499     index_shape->data[1] = 2;
500     context->ResizeTensor(context, indices, index_shape);
501 
502     TfLiteIntArray* output_shape = TfLiteIntArrayCreate(1);
503     output_shape->data[0] = total_num_features;
504     context->ResizeTensor(context, values, output_shape);
505 
506     SparseBuffer& buffer = sparse_buffers[d];
507 
508     // Update indices.
509     auto* indices_p = reinterpret_cast<int64_t*>(indices->data.raw);
510     if (!indices_p) {
511       return tf::errors::Internal("Indices tensor not allocated!");
512     }
513 
514     if (total_num_features > 0) {
515       int64_t* ix_p = indices_p;
516       size_t example_index = 0;
517       int idx0 = 0;
518       size_t delta = 0;
519       for (size_t example_end_index : buffer.example_end_indices) {
520         size_t feature_index = 0;
521         for (; delta < example_end_index; ++delta) {
522           // Column 0: example index
523           if (idx0 < total_num_features) {
524             *ix_p = example_index;
525             // Column 1: the feature index buffer example
526             *(ix_p + 1) = feature_index;
527             ix_p += 2;
528           }
529           ++feature_index;
530           ++idx0;
531         }
532         ++example_index;
533       }
534       CopySparseBufferToTensor(config.sparse[d].dtype, 0, &buffer, values);
535     }
536   }
537 
538   // Merge SparseBuffers from all minibatches for every config.dense having
539   // variable_length.
540   for (size_t d = 0; d < config.dense.size(); ++d) {
541     if (!config.dense[d].variable_length) {
542       continue;
543     }
544     size_t max_num_features = 0;
545     std::vector<size_t>& end_indices =
546         varlen_dense_buffers[d].example_end_indices;
547     max_num_features = std::max(max_num_features, end_indices[0]);
548     for (size_t i = 1; i < end_indices.size(); ++i) {
549       size_t example_size = end_indices[i] - end_indices[i - 1];
550       max_num_features = std::max(max_num_features, example_size);
551     }
552 
553     const size_t stride_size = config.dense[d].elements_per_stride;
554     const size_t max_num_elements = max_num_features / stride_size;
555     tf::TensorShape values_shape;
556     DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
557     const size_t batch_size = GetStringCount(serialized);
558     values_shape.AddDim(batch_size);
559     values_shape.AddDim(max_num_elements);
560     for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
561       values_shape.AddDim(config.dense[d].shape.dim_size(i));
562     }
563     TfLiteTensor* values = result->dense_values[d];
564     const size_t num_elements = GetTensorShape(values).FlatSize();
565 
566     // Nothing to write, exit early.
567     if (num_elements == 0) {
568       continue;
569     }
570 
571     const size_t num_elements_per_minibatch = num_elements / batch_size;
572     switch (config.dense[d].dtype) {
573       case tf::DT_INT64: {
574         FillAndCopyVarLen<int64_t>(d, num_elements, num_elements_per_minibatch,
575                                    config, varlen_dense_buffers, values);
576         break;
577       }
578       case tf::DT_FLOAT: {
579         FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
580                                  config, varlen_dense_buffers, values);
581         break;
582       }
583       default:
584         DCHECK(false) << "Encountered unexpected DataType "
585                       << config.dense[d].dtype
586                       << "in variable that should have been checked";
587     }
588   }
589 
590   // Merge tflite string buffers if necessary.
591   for (size_t d = 0; d < config.dense.size(); ++d) {
592     if (config.dense[d].variable_length) {
593       continue;
594     }
595     if (result->dense_values[d]->type == kTfLiteString) {
596       auto& in = result->dense_tensors[d];
597       auto vec = in.vec<tstring>();
598       const int batch_size = result->dense_values[d]->dims->data[0];
599       const int elements_per_stride = config.dense[d].elements_per_stride;
600       int total_size = 0;
601       std::vector<int32_t> offsets;
602       offsets.reserve(vec.size() + 1);
603       offsets.push_back(0);
604       int k = 0;
605       for (int i = 0; i < batch_size; ++i) {
606         for (int j = 0; j < elements_per_stride; ++j) {
607           if (i < count) {
608             total_size += vec(k++).size();
609             offsets.push_back(total_size);
610           } else {
611             offsets.push_back(total_size);
612           }
613         }
614       }
615       const int32_t num_strings = offsets.size() - 1;
616       const size_t required_bytes = sizeof(int32_t) * (num_strings + 2) +
617           total_size;
618       char* tensor_buffer =
619           reinterpret_cast<char*>(result->dense_values[d]->data.raw);
620       if (result->dense_values[d]->bytes < required_bytes) {
621         if (result->dense_values[d]->data.raw) {
622           free(result->dense_values[d]->data.raw);
623         }
624         tensor_buffer = reinterpret_cast<char*>(malloc(required_bytes));
625         result->dense_values[d]->data.raw = tensor_buffer;
626         result->dense_values[d]->bytes = required_bytes;
627       }
628       const int32_t start = sizeof(int32_t) * (num_strings + 2);
629       memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
630       for (size_t i = 0; i < offsets.size(); i++) {
631         int32_t offset_i = start + offsets[i];
632         memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset_i,
633                sizeof(int32_t));
634       }
635       tf::gtl::ArraySlice<tstring> slice(vec.data(), vec.size());
636       CopyToBuffer(slice, tensor_buffer + start, count, batch_size,
637                    elements_per_stride);
638     }
639   }
640   return ::tensorflow::OkStatus();
641 }
642 
643 }  // namespace
644 
// Input tensor indices for the ParseExample custom op.
enum InputTensor {
  kExampleTensor = 0,     // Serialized tf.Example strings (batch input).
  kNamesTensor = 1,       // Optional per-example debug names.
  kSparseKeysTensor = 2,  // Sparse feature key strings.
  kDenseKeysTensor = 3,   // Dense feature key strings.
  kRaggedKeysTensor = 4,  // Ragged feature key strings.
};
652 
// Persistent per-node state, created in Init() and owned by the runtime.
struct OpData {
  FastParseExampleConfig config;  // Dense/sparse feature parse configuration.
  std::vector<tf::TensorShape> dense_shapes;  // Expected dense output shapes.
  int dense_size = 0;   // Number of dense features.
  int sparse_size = 0;  // Number of sparse features.
  // Cuckoo map from feature-name fingerprint to (feature index, Dense/Sparse).
  std::unique_ptr<ConfigIndex> config_index;
  int config_index_size;  // NOTE(review): not initialized here — presumably
                          // set during Prepare; confirm before reading.
  SeededHasher hasher;  // Seeded fingerprint hasher for config_index lookups.
  TfLiteResult got;     // Output tensor handles/staging filled during eval.
  // Length-indexed pre-filter: quick_filter[len] is true iff some configured
  // feature name has length `len`.  Allocated with malloc elsewhere
  // (allocation site not in this view), hence the free() below.
  bool* quick_filter = nullptr;
  int quick_filter_size;  // NOTE(review): also set outside this view.
  bool created = false;   // Whether the index/filter have been built.
  ~OpData() {
    if (quick_filter) {
      free(quick_filter);
    }
  }
};
671 
Init(TfLiteContext * context,const char * buffer,size_t length)672 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
673   return new OpData;
674 }
675 
676 template <typename T>
AsTensor(const std::vector<T> & val)677 tf::Tensor AsTensor(const std::vector<T>& val) {
678   tf::Tensor ret(tf::DataTypeToEnum<T>::value,
679                  {static_cast<int64_t>(val.size())});
680   std::copy_n(val.begin(), val.size(), ret.flat<T>().data());
681   return ret;
682 }
683 
// Which ParseExample op generation's NodeDef attribute layout to read in
// PrepareParseExample.
enum Version {
  V1,  // Reads "Ndense" / "Nsparse" attributes.
  V2,  // Reads "Tdense" type list / "num_sparse" attributes.
};
688 
TfLiteToTfShape(TfLiteIntArray * array)689 tf::TensorShape TfLiteToTfShape(TfLiteIntArray* array) {
690   tf::TensorShape shape;
691   for (int i = 0; i < array->size; i++) {
692     shape.AddDim(array->data[i]);
693   }
694   return shape;
695 }
696 
697 template <Version version>
PrepareParseExample(TfLiteContext * context,TfLiteNode * node)698 TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
699   OpData* data = reinterpret_cast<OpData*>(node->user_data);
700   TF_LITE_ENSURE(context, node->custom_initial_data);
701   data->config.dense.clear();
702   data->config.sparse.clear();
703   data->got.dense_values.clear();
704   const flexbuffers::Vector& v =
705       flexbuffers::GetRoot(
706           reinterpret_cast<const uint8_t*>(node->custom_initial_data),
707           node->custom_initial_data_size)
708           .AsVector();
709   if (v.size() == 2) {
710     tf::NodeDef nodedef;
711     TF_LITE_ENSURE_EQ(context, nodedef.ParseFromString(v[1].AsString().str()),
712                       true);
713     if (version == V1) {
714       data->dense_size = nodedef.attr().at("Ndense").i();
715       data->sparse_size = nodedef.attr().at("Nsparse").i();
716     } else if (version == V2) {
717       data->dense_size = nodedef.attr().at("Tdense").list().type_size();
718       data->sparse_size = nodedef.attr().at("num_sparse").i();
719     }
720     auto dense_shapes = nodedef.attr().at("dense_shapes").list();
721     if (data->dense_shapes.empty()) {
722       for (int i = 0; i < dense_shapes.shape_size(); ++i) {
723         data->dense_shapes.push_back(dense_shapes.shape(i));
724       }
725     }
726   } else {
727     const flexbuffers::Map& m =
728         flexbuffers::GetRoot(
729             reinterpret_cast<const uint8_t*>(node->custom_initial_data),
730             node->custom_initial_data_size)
731             .AsMap();
732     const flexbuffers::TypedVector keys = m.Keys();
733     int num_sparse = 0;
734     int num_dense = 0;
735     for (int k = 0; k < keys.size(); ++k) {
736       const std::string key = keys[k].ToString();
737       const auto value = m[key];
738       if (key == "Nsparse" || key == "num_sparse") {
739         num_sparse = value.AsInt32();
740       }
741       if (key == "Ndense") {
742         num_dense = value.AsInt32();
743       }
744     }
745     data->sparse_size = num_sparse;
746     data->dense_size = num_dense;
747     if (version == V2) {
748       const TfLiteTensor* dense_key_tensor =
749           GetInput(context, node, kDenseKeysTensor);
750       data->dense_size = GetTensorShape(dense_key_tensor).FlatSize();
751     }
752   }
753 
754   data->config.dense.reserve(data->dense_size);
755   data->config.sparse.reserve(data->sparse_size);
756   data->dense_shapes.reserve(data->dense_size);
757   const auto* serialized = GetInput(context, node, 0);
758   const int batch_size =
759       serialized->dims->size > 0 ? serialized->dims->data[0] : 1;
760   const bool missing_shape_info = data->dense_shapes.empty();
761   for (int i = 0; i < data->dense_size; i++) {
762     TfLiteTensor* dense_key_tensor =
763         GetOutput(context, node, data->sparse_size * 3 + i);
764     TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_key_tensor->dims);
765     if (missing_shape_info) {
766       data->dense_shapes.push_back(TfLiteToTfShape(output_size));
767     }
768     // use original tflite tensor size if inputs are resized.
769     const int original_size = data->dense_shapes[i].dims() > 0
770                                   ? data->dense_shapes[i].dim_size(0)
771                                   : 1;
772     output_size->data[0] = batch_size * original_size;
773     context->ResizeTensor(context, dense_key_tensor, output_size);
774   }
775 
776   size_t offset = 0;
777   for (int i = 0; i < data->sparse_size; i++) {
778     auto* parse_output = GetOutput(context, node, i + offset);
779     SetTensorToDynamic(parse_output);
780     TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(2);
781     sparse_size->data[0] = batch_size;
782     sparse_size->data[1] = 2;
783     context->ResizeTensor(context, parse_output, sparse_size);
784     data->got.sparse_indices.push_back(parse_output);
785   }
786   offset += data->sparse_size;
787   for (int i = 0; i < data->sparse_size; i++) {
788     auto* parse_output = GetOutput(context, node, i + offset);
789     SetTensorToDynamic(parse_output);
790     TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
791     sparse_size->data[0] = 0;
792     context->ResizeTensor(context, parse_output, sparse_size);
793     data->got.sparse_values.push_back(parse_output);
794   }
795   offset += data->sparse_size;
796   for (int i = 0; i < data->sparse_size; i++) {
797     TfLiteTensor* parse_output = GetOutput(context, node, i + offset);
798     SetTensorToDynamic(parse_output);
799     TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
800     sparse_size->data[0] = 2;
801     context->ResizeTensor(context, parse_output, sparse_size);
802     auto* shapes_shape_t = reinterpret_cast<int64_t*>(parse_output->data.i64);
803     shapes_shape_t[0] = batch_size;
804     shapes_shape_t[1] = 1;
805     data->got.sparse_shapes.push_back(parse_output);
806   }
807   data->created = false;
808   return kTfLiteOk;
809 }
810 
811 template <Version version>
EvalParseExample(TfLiteContext * context,TfLiteNode * node)812 TfLiteStatus EvalParseExample(TfLiteContext* context, TfLiteNode* node) {
813   OpData* data = reinterpret_cast<OpData*>(node->user_data);
814   if (!data->created) {
815     for (int i = 0; i < data->sparse_size; i++) {
816       int input_index =
817           version == V1 ? kSparseKeysTensor + i : kSparseKeysTensor;
818       int string_index = version == V1 ? 0 : i;
819       const TfLiteTensor* sparse_key_tensor =
820           GetInput(context, node, input_index);
821       const auto key = GetString(sparse_key_tensor, string_index);
822       const auto* sparse_output =
823           GetOutput(context, node, i + data->sparse_size);
824       std::string k(key.str, key.len);
825       switch (sparse_output->type) {
826         case kTfLiteInt64:
827           data->config.sparse.emplace_back(k,
828                                            tf::DataTypeToEnum<int64_t>::value);
829           break;
830         case kTfLiteFloat32:
831           data->config.sparse.emplace_back(k, tf::DataTypeToEnum<float>::value);
832           break;
833         case kTfLiteString:
834           data->config.sparse.emplace_back(k,
835                                            tf::DataTypeToEnum<tstring>::value);
836           break;
837         default:
838           return kTfLiteError;
839       }
840     }
841 
842     const auto& dense_shapes = data->dense_shapes;
843     for (int i = 0; i < data->dense_size; i++) {
844       const int input_index = version == V1
845                                   ? kSparseKeysTensor + data->sparse_size + i
846                                   : kDenseKeysTensor;
847       const int dense_defaults_index =
848           version == V1
849               ? kSparseKeysTensor + data->sparse_size + data->dense_size + i
850               : kRaggedKeysTensor + i + 1;
851       int string_index = version == V1 ? 0 : i;
852       const TfLiteTensor* dense_key_tensor =
853           GetInput(context, node, input_index);
854       const auto* dense_output =
855           GetOutput(context, node, i + data->sparse_size * 3);
856       const auto* dense_defaults =
857           GetInput(context, node, dense_defaults_index);
858       const auto key = GetString(dense_key_tensor, string_index);
859       std::string k(key.str, key.len);
860       const int elements_per_stride =
861           dense_shapes[i].dims() ? dense_shapes[i].num_elements() : 1;
862       switch (dense_output->type) {
863         case kTfLiteInt64:
864           data->config.dense.emplace_back(
865               k, tf::DataTypeToEnum<int64_t>::value, dense_shapes[i],
866               AsTensor<int64_t>(std::vector<int64_t>(
867                   dense_defaults->data.i64,
868                   dense_defaults->data.i64 + elements_per_stride)),
869               false, elements_per_stride);
870           break;
871         case kTfLiteFloat32:
872           data->config.dense.emplace_back(
873               k, tf::DataTypeToEnum<float>::value, dense_shapes[i],
874               AsTensor<float>(std::vector<float>(
875                   dense_defaults->data.f,
876                   dense_defaults->data.f + elements_per_stride)),
877               false, elements_per_stride);
878           break;
879         case kTfLiteString: {
880           const int num_strings = GetStringCount(dense_defaults);
881           std::vector<tstring> values;
882           for (int i = 0; i < num_strings; ++i) {
883             auto ref = GetString(dense_defaults, i);
884             values.emplace_back(ref.str, ref.len);
885           }
886           data->config.dense.emplace_back(
887               k, tf::DataTypeToEnum<tstring>::value, dense_shapes[i],
888               AsTensor<tstring>(values), false, elements_per_stride);
889           break;
890         }
891         default:
892           return kTfLiteError;
893       }
894     }
895 
896     int offset = 3 * data->sparse_size;
897     for (int i = 0; i < data->dense_size; i++) {
898       auto* parse_output = GetOutput(context, node, i + offset);
899       data->got.dense_values.push_back(parse_output);
900       if (parse_output->type == kTfLiteString) {
901         tf::TensorShape shape;
902         if (parse_output->dims->size == 1) {
903           shape.AddDim(parse_output->dims->data[0]);
904         } else {
905           shape.AddDim(GetTensorShape(parse_output).FlatSize());
906         }
907         data->got.dense_tensors[i] =
908             tf::Tensor(tf::DataTypeToEnum<tstring>::value, shape);
909       }
910     }
911 
912     size_t config_size = data->config.dense.size();
913     config_size += data->config.sparse.size();
914     data->config_index_size = config_size;
915     auto config_index = std::make_unique<ConfigIndex>(config_size);
916     bool ok = true;
917     int max_length = 0;
918     for (size_t d = 0; d < data->config.dense.size(); ++d) {
919       auto s = data->config.dense[d].feature_name;
920       max_length = s.length() > max_length ? s.length() : max_length;
921     }
922     for (size_t d = 0; d < data->config.sparse.size(); ++d) {
923       auto s = data->config.sparse[d].feature_name;
924       max_length = s.length() > max_length ? s.length() : max_length;
925     }
926     if (data->quick_filter) {
927       free(data->quick_filter);
928     }
929     data->quick_filter =
930         static_cast<bool*>(malloc(++max_length * sizeof(bool)));
931     memset(data->quick_filter, 0, max_length * sizeof(bool));
932     data->quick_filter_size = max_length;
933     for (size_t d = 0; d < data->config.dense.size(); ++d) {
934       const auto& s = data->config.dense[d].feature_name;
935       data->quick_filter[s.length()] = true;
936     }
937     for (size_t d = 0; d < data->config.sparse.size(); ++d) {
938       const auto& s = data->config.sparse[d].feature_name;
939       data->quick_filter[s.length()] = true;
940     }
941 
942     for (int i = 0; i < 1000; ++i) {
943       for (size_t d = 0; d < data->config.dense.size(); ++d) {
944         ok &= config_index->InsertUnique(
945             data->hasher(data->config.dense[d].feature_name), {d, Type::Dense});
946       }
947       for (size_t d = 0; d < data->config.sparse.size(); ++d) {
948         ok &= config_index->InsertUnique(
949             data->hasher(data->config.sparse[d].feature_name),
950             {d, Type::Sparse});
951       }
952       if (ok) {
953         break;
954       }
955       data->hasher.seed++;
956       config_index->Clear(config_size);
957       ok = true;
958     }
959     if (!ok) {
960       return kTfLiteError;
961     }
962     data->config_index = std::move(config_index);
963     data->created = true;
964   }
965 
966   const TfLiteTensor* serialized = GetInput(context, node, kExampleTensor);
967 
968   std::map<absl::string_view, int> stats;
969   const auto status = FastParseExampleLite(
970       data->config, serialized, {}, data->quick_filter, data->quick_filter_size,
971       data->config_index, data->config_index_size, &data->hasher, &data->got,
972       stats, context);
973   if (status != ::tensorflow::OkStatus()) {
974     TF_LITE_KERNEL_LOG(context, status.ToString().c_str());
975     return kTfLiteError;
976   }
977   return kTfLiteOk;
978 }
979 
Free(TfLiteContext * context,void * buffer)980 void Free(TfLiteContext* context, void* buffer) {
981   auto* obj = reinterpret_cast<OpData*>(buffer);
982   delete obj;
983 }
984 
985 }  // namespace parse_example
986 
Register_PARSE_EXAMPLE()987 TfLiteRegistration* Register_PARSE_EXAMPLE() {
988   static TfLiteRegistration r = {
989       parse_example::Init, parse_example::Free,
990       parse_example::PrepareParseExample<parse_example::V1>,
991       parse_example::EvalParseExample<parse_example::V1>};
992   return &r;
993 }
994 
Register_PARSE_EXAMPLE_V2()995 TfLiteRegistration* Register_PARSE_EXAMPLE_V2() {
996   static TfLiteRegistration r = {
997       parse_example::Init, parse_example::Free,
998       parse_example::PrepareParseExample<parse_example::V2>,
999       parse_example::EvalParseExample<parse_example::V2>};
1000   return &r;
1001 }
1002 
AddParseExampleOp(::tflite::MutableOpResolver * resolver)1003 extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver) {
1004   resolver->AddCustom("ParseExample", Register_PARSE_EXAMPLE());
1005   resolver->AddCustom("ParseExampleV2", Register_PARSE_EXAMPLE_V2());
1006 }
1007 
1008 }  // namespace custom
1009 }  // namespace ops
1010 }  // namespace tflite
1011