• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/parse_example/parse_example.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <memory>
20 #include <unordered_map>
21 
22 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
23 #include "tensorflow/core/example/feature.pb.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/lib/core/blocking_counter.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/fingerprint.h"
30 #include "tensorflow/core/public/session_options.h"
31 #include "tensorflow/core/util/example_proto_fast_parsing.h"
32 #include "tensorflow/core/util/presized_cuckoo_map.h"
33 #include "tensorflow/lite/c/common.h"
34 #include "tensorflow/lite/kernels/internal/tensor.h"
35 #include "tensorflow/lite/kernels/kernel_util.h"
36 #include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"
37 #include "tensorflow/lite/mutable_op_resolver.h"
38 #include "tensorflow/lite/string_util.h"
39 
40 namespace tflite {
41 namespace ops {
42 namespace custom {
43 namespace parse_example {
44 namespace {
45 
46 namespace tf = ::tensorflow;
47 using tf::Status;
48 using tf::StringPiece;
49 using tf::tstring;
50 using tf::example::CopyOrMoveBlock;
51 using tf::example::FastParseExampleConfig;
52 using tf::example::GetListFromBuffer;
53 using tf::example::LimitedArraySlice;
54 using tf::example::ParseExample;
55 using tf::example::SeededHasher;
56 using tf::example::SmallVector;
57 using tf::example::SparseBuffer;
58 using tf::example::Type;
59 using tf::example::parsed::Example;
60 
61 using ConfigIndex = tf::PresizedCuckooMap<std::pair<int32_t, Type>>;
62 
// Collects the TFLite output tensors filled by one invocation of the
// ParseExample kernel, mirroring the dense/sparse outputs of the TF op.
struct TfLiteResult {
  // One output tensor per configured dense feature.
  std::vector<TfLiteTensor*> dense_values;
  // Per configured sparse feature: values, COO indices, and dense-shape
  // tensors, indexed in parallel.
  std::vector<TfLiteTensor*> sparse_values;
  std::vector<TfLiteTensor*> sparse_indices;
  std::vector<TfLiteTensor*> sparse_shapes;
  // Staging tf::Tensors keyed by dense-feature index. Used for string
  // features, which are parsed into tf::Tensor<tstring> first and then
  // serialized into the TFLite string-tensor format.
  std::map<int, tf::Tensor> dense_tensors;
};
70 
71 template <typename T>
FillAndCopyVarLen(const int d,const size_t num_elements,const size_t num_elements_per_minibatch,const FastParseExampleConfig & config,std::vector<SparseBuffer> & varlen_dense_buffers,TfLiteTensor * values)72 void FillAndCopyVarLen(const int d, const size_t num_elements,
73                        const size_t num_elements_per_minibatch,
74                        const FastParseExampleConfig& config,
75                        std::vector<SparseBuffer>& varlen_dense_buffers,
76                        TfLiteTensor* values) {
77   const tf::Tensor& default_value = config.dense[d].default_value;
78 
79   // Copy-fill the tensors (creating the zero/fill-padding)
80   std::fill(reinterpret_cast<T*>(values->data.raw),
81             reinterpret_cast<T*>(values->data.raw) + num_elements,
82             default_value.flat<T>()(0));
83 
84   auto data = reinterpret_cast<T*>(values->data.raw);
85 
86   const SparseBuffer& buffer = varlen_dense_buffers[d];
87   // Number of examples being stored in this buffer
88   const auto& end_indices = buffer.example_end_indices;
89   const size_t examples_in_buffer = end_indices.size();
90 
91   const auto& list = GetListFromBuffer<T>(buffer);
92   auto list_ptr = list.begin();
93 
94   size_t elements_tally = 0;
95   // Iterate through all the examples stored in this buffer.
96   for (size_t j = 0; j < examples_in_buffer; ++j) {
97     // Number of elements stored for this example.
98     const size_t num_elems = end_indices[j] - elements_tally;
99     CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
100     // Move forward this many elements in the varlen buffer.
101     list_ptr += num_elems;
102     // Move forward to the next minibatch entry in the values output.
103     data += num_elements_per_minibatch;
104     elements_tally = end_indices[j];
105   }
106   DCHECK(elements_tally == list.size());
107 }
108 
ParseExample(StringRef serialized,Example * example)109 bool ParseExample(StringRef serialized, Example* example) {
110   DCHECK(example != nullptr);
111   tf::protobuf::io::CodedInputStream stream(
112       reinterpret_cast<const uint8*>(serialized.str), serialized.len);
113   tensorflow::example::EnableAliasing(&stream);
114   return ParseExample(&stream, example);
115 }
116 
// Parses a single serialized tf.Example and scatters its features into the
// per-feature output buffers:
//   - fixed-shape dense features are written directly into `output_dense`
//     (or `result->dense_tensors` for strings),
//   - variable-length dense features accumulate into `output_varlen_dense`,
//   - sparse features accumulate into `output_sparse`.
// Missing features are back-filled at the end (defaults for fixed dense,
// repeated end-indices for varlen/sparse). `quick_filter` is indexed by
// feature-name length as a cheap pre-filter; `config_index` maps the
// fingerprint produced by `hasher` to (feature index, Type).
// `stats` is accepted but not used in this function.
// Returns tf::errors::Internal on any parse/shape/type mismatch.
Status FastParseSerializedExample(
    StringRef serialized_example, const tstring& example_name,
    const size_t example_index, const FastParseExampleConfig& config,
    bool* quick_filter, int quick_filter_size,
    const std::unique_ptr<ConfigIndex>& config_index, int config_index_size,
    SeededHasher* hasher, std::vector<TfLiteTensor*>* output_dense,
    std::vector<SparseBuffer>* output_varlen_dense,
    std::vector<SparseBuffer>* output_sparse,
    std::map<absl::string_view, int>& stats, TfLiteResult* result) {
  DCHECK(output_dense != nullptr);
  tensorflow::example::parsed::Example parsed_example;
  if (!ParseExample(serialized_example, &parsed_example)) {
    return tf::errors::Internal("Failed to parse example");
  }
  // Track, per feature, the last example index that wrote it so that
  // duplicates are skipped and missing features can be back-filled below.
  std::vector<tf::int64> dense_feature_last_example(config.dense.size(), -1);
  std::vector<tf::int64> sparse_feature_last_example(config.sparse.size(), -1);
  // Handle features present in the example.
  const size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    // (Hence the reverse iteration order.)
    tensorflow::example::parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    const StringPiece feature_name = name_and_feature.first;
    tensorflow::example::parsed::Feature& feature = name_and_feature.second;
    // Length-based pre-filter: skip names whose length no configured
    // feature has. NOTE(review): size_t vs int comparison here — confirm
    // quick_filter_size is always non-negative at the call site.
    if (feature_name.length() >= quick_filter_size ||
        !quick_filter[feature_name.length()]) {
      continue;
    }
    // Look up the feature by fingerprint; absent means unconfigured.
    const uint64_t h = (*hasher)(feature_name);
    std::pair<int32_t, Type> d_and_type;
    if (!config_index->Find(h, &d_and_type)) {
      continue;
    }
    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;

    auto example_error = [&](StringPiece suffix) {
      return tf::errors::Internal("Name: ", example_name,
                                  ", Key: ", feature_name,
                                  ", Index: ", example_index, ".  ", suffix);
    };

    auto parse_error = [&] {
      return example_error("Can't parse serialized Example.");
    };

    tf::DataType example_dtype;
    if (feature.ParseDataType(&example_dtype) != Status::OK()) {
      return parse_error();
    }
    if (is_dense) {
      // DT_INVALID means the feature key was present with no value list.
      if (example_dtype == tf::DT_INVALID) continue;

      dense_feature_last_example[d] = example_index;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(absl::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }
      if (!config.dense[d].variable_length) {
        // Fixed-shape dense: parse directly into the output tensor at
        // this example's stride offset.
        TfLiteTensor* out = (*output_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;
        const std::size_t offset = example_index * num_elements;

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(absl::StrCat(
              "Number of ", type_str,
              " values != expected.  "
              "Values size:",
              size,
              " but output shape: ", config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case tf::DT_INT64: {
            auto out_p = reinterpret_cast<tf::int64*>(out->data.raw) + offset;
            LimitedArraySlice<tf::int64> slice(out_p, num_elements);
            if (!feature.ParseInt64List(&slice)) return parse_error();
            // EndDistance() == 0 means exactly num_elements were written.
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "int64");
            }
            break;
          }
          case tf::DT_FLOAT: {
            auto out_p = reinterpret_cast<float*>(out->data.raw) + offset;
            LimitedArraySlice<float> slice(out_p, num_elements);
            if (!feature.ParseFloatList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "float");
            }
            break;
          }
          case tf::DT_STRING: {
            // Strings are staged in a tf::Tensor; they are converted to
            // the TFLite string format later (see FastParseExampleLite).
            auto& out_tensor = result->dense_tensors[d];
            auto out_p = out_tensor.flat<tstring>().data() + offset;
            LimitedArraySlice<tstring> slice(out_p, num_elements);
            if (!feature.ParseBytesList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "bytes");
            }
            break;
          }
          default:
            return tf::errors::Internal("Unrecognized dense type: ",
                                        config.dense[d].dtype);
        }
      } else {  // if dense variable length
        // Varlen dense: append to a SparseBuffer; padding/layout happens
        // later in FillAndCopyVarLen.
        SparseBuffer& out = (*output_varlen_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;

        if (example_dtype != tf::DT_INVALID &&
            example_dtype != config.dense[d].dtype) {
          return example_error(absl::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(config.dense[d].dtype)));
        }

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(
              absl::StrCat("Number of ", type_str,
                           " values is not a multiple of stride length. Saw ",
                           size, " values but output shape is: ",
                           config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case tf::DT_INT64: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseInt64List(&out.int64_list)) {
                return parse_error();
              }
              // Each example must contribute whole strides.
              if (out.int64_list.size() % num_elements != 0) {
                return shape_error(out.int64_list.size(), "int64");
              }
            }
            out.example_end_indices.push_back(out.int64_list.size());
            break;
          }
          case tf::DT_FLOAT: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseFloatList(&out.float_list)) {
                return parse_error();
              }
              if (out.float_list.size() % num_elements != 0) {
                return shape_error(out.float_list.size(), "float");
              }
            }
            out.example_end_indices.push_back(out.float_list.size());
            break;
          }
          case tf::DT_STRING: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseBytesList(&out.bytes_list)) {
                return parse_error();
              }
              if (out.bytes_list.size() % num_elements != 0) {
                return shape_error(out.bytes_list.size(), "byte");
              }
            }
            out.example_end_indices.push_back(out.bytes_list.size());
            break;
          }
          default:
            return tf::errors::Internal("Should not happen: ",
                                        config.dense[d].dtype);
        }
      }
    } else {
      // is sparse or ragged
      auto& last_example = sparse_feature_last_example;
      // Skip duplicate occurrences within the same example (first one
      // seen in reverse order wins, matching protobuf map semantics).
      if (last_example[d] == example_index) {
        continue;
      }
      last_example[d] = example_index;
      SparseBuffer& out = (*output_sparse)[d];
      tf::DataType feature_dtype = config.sparse[d].dtype;
      if (example_dtype != tf::DT_INVALID && example_dtype != feature_dtype) {
        return tf::errors::Internal("Data types don't match:", example_dtype,
                                    " != ", feature_dtype);
      }
      switch (feature_dtype) {
        case tf::DT_INT64: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseInt64List(&out.int64_list)) {
              return parse_error();
            }
          }
          // Always record an end index, even for an empty value list.
          out.example_end_indices.push_back(out.int64_list.size());
          break;
        }
        case tf::DT_FLOAT: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseFloatList(&out.float_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.float_list.size());
          break;
        }
        case tf::DT_STRING: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseBytesList(&out.bytes_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.bytes_list.size());
          break;
        }
        default:
          return tf::errors::Internal("Should not happen: ", feature_dtype);
      }
    }
  }
  // Handle missing dense features for fixed strides.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    // A fixed-shape dense feature with no default is required.
    if (config.dense[d].default_value.NumElements() == 0) {
      return tf::errors::Internal(
          "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
          " (data type: ", DataTypeString(config.dense[d].dtype), ")",
          " is required but could not be found.");
    }
    // Copy the default value into this example's slot.
    const tf::Tensor& in = config.dense[d].default_value;
    TfLiteTensor* out = result->dense_values[d];
    const std::size_t num_elements = in.shape().num_elements();
    const std::size_t offset = example_index * num_elements;
    switch (config.dense[d].dtype) {
      case tf::DT_INT64: {
        std::copy_n(in.flat<tf::int64>().data(), num_elements,
                    out->data.i64 + offset);
        break;
      }
      case tf::DT_FLOAT: {
        std::copy_n(in.flat<float>().data(), num_elements,
                    out->data.f + offset);
        break;
      }
      case tf::DT_STRING: {
        // Strings go to the staging tf::Tensor, same as the parse path.
        auto& out_tensor = result->dense_tensors[d];
        std::copy_n(in.flat<tstring>().data(), num_elements,
                    out_tensor.flat<tstring>().data() + offset);
        break;
      }
      default:
        return tf::errors::Internal("Should not happen: ",
                                    config.dense[d].dtype);
    }
  }
  // Missing varlen dense features contribute zero elements: repeat the
  // previous end index so every example has an entry.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_varlen_dense)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Likewise for missing sparse features.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (sparse_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_sparse)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  return Status::OK();
}
390 
CountSparseFeatures(const SparseBuffer & sparse_buffer,size_t * total_num_features,size_t * max_num_features)391 void CountSparseFeatures(const SparseBuffer& sparse_buffer,
392                          size_t* total_num_features, size_t* max_num_features) {
393   const std::vector<size_t>& end_indices = sparse_buffer.example_end_indices;
394   *total_num_features += end_indices.back();
395   *max_num_features = std::max(*max_num_features, end_indices[0]);
396   for (size_t i = 1; i < end_indices.size(); ++i) {
397     size_t example_size = end_indices[i] - end_indices[i - 1];
398     *max_num_features = std::max(*max_num_features, example_size);
399   }
400 }
401 
CopySparseBufferToTensor(tf::DataType dtype,size_t offset,SparseBuffer * src,TfLiteTensor * dst)402 void CopySparseBufferToTensor(tf::DataType dtype, size_t offset,
403                               SparseBuffer* src, TfLiteTensor* dst) {
404   switch (dtype) {
405     case tf::DT_INT64: {
406       std::copy(src->int64_list.begin(), src->int64_list.end(),
407                 reinterpret_cast<int64_t*>(dst->data.raw) + offset);
408       break;
409     }
410     case tf::DT_FLOAT: {
411       std::copy(src->float_list.begin(), src->float_list.end(),
412                 reinterpret_cast<float*>(dst->data.raw) + offset);
413       break;
414     }
415     case tf::DT_STRING: {
416       DynamicBuffer buffer;
417       for (auto* begin = src->bytes_list.begin();
418            begin != src->bytes_list.end(); begin++) {
419         buffer.AddString(begin->c_str(), begin->size());
420       }
421       buffer.WriteToTensor(dst, nullptr);
422       break;
423     }
424     default:
425       DCHECK(false) << "Encountered unexpected DataType "
426                     << DataTypeString(dtype)
427                     << "in variable that should have been checked.";
428   }
429 }
430 
CopyToBuffer(tf::gtl::ArraySlice<tstring> vec,char * tensor_buffer,int num_examples,int batch_size,int elements_per_stride)431 inline void CopyToBuffer(tf::gtl::ArraySlice<tstring> vec, char* tensor_buffer,
432                          int num_examples, int batch_size,
433                          int elements_per_stride) {
434   int i = 0, k = 0;
435   int start = 0;
436   for (; i < num_examples; ++i) {
437     for (int j = 0; j < elements_per_stride; ++j) {
438       memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
439       start += vec[k].size();
440       k++;
441     }
442   }
443   // Will happen if the number of examples is less than the desired batch size.
444   for (; i < batch_size; ++i) {
445     for (int j = 0; j < elements_per_stride; ++j) {
446       memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
447       start += vec[k].size();
448       k++;
449     }
450   }
451 }
452 
FastParseExampleLite(const FastParseExampleConfig & config,const TfLiteTensor * serialized,tf::gtl::ArraySlice<tstring> example_names,bool * quick_filter,int quick_filter_size,const std::unique_ptr<ConfigIndex> & config_index,int config_index_size,SeededHasher * hasher,TfLiteResult * result,std::map<absl::string_view,int> & stats,TfLiteContext * context)453 Status FastParseExampleLite(
454     const FastParseExampleConfig& config, const TfLiteTensor* serialized,
455     tf::gtl::ArraySlice<tstring> example_names, bool* quick_filter,
456     int quick_filter_size, const std::unique_ptr<ConfigIndex>& config_index,
457     int config_index_size, SeededHasher* hasher, TfLiteResult* result,
458     std::map<absl::string_view, int>& stats, TfLiteContext* context) {
459   if (result == nullptr) {
460     return tf::errors::Internal("Result is null");
461   }
462   const int count = GetStringCount(serialized);
463   std::vector<tf::Tensor> fixed_dense_values(config.dense.size());
464   std::vector<SparseBuffer> sparse_buffers(config.sparse.size());
465   std::vector<SparseBuffer> varlen_dense_buffers(config.dense.size());
466   Status status_of_minibatch;
467   for (size_t e = 0; e < count; ++e) {
468     Status status_of_minibatch = FastParseSerializedExample(
469         GetString(serialized, e),
470         (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
471         quick_filter, quick_filter_size, config_index, config_index_size,
472         hasher, &result->dense_values, &varlen_dense_buffers, &sparse_buffers,
473         /*arena,*/ stats, result);
474     if (!status_of_minibatch.ok()) break;
475   }
476   if (!status_of_minibatch.ok()) {
477     return status_of_minibatch;
478   }
479   // Merge SparseBuffers from all minibatches for every config.sparse.
480   // auto MergeSparseMinibatches = [&](size_t d) {
481   // Loop over minibatches
482   for (size_t d = 0; d < config.sparse.size(); ++d) {
483     size_t total_num_features = 0;
484     size_t max_num_features = 0;
485     CountSparseFeatures(sparse_buffers[d], &total_num_features,
486                         &max_num_features);
487     tf::TensorShape indices_shape;
488     TfLiteTensor* indices = result->sparse_indices[d];
489     TfLiteTensor* values = result->sparse_values[d];
490 
491     TfLiteTensor* dense_shape = result->sparse_shapes[d];
492     auto* dense_shape_ptr = reinterpret_cast<int64_t*>(dense_shape->data.raw);
493     dense_shape_ptr[1] = max_num_features;
494 
495     TfLiteIntArray* index_shape = TfLiteIntArrayCreate(2);
496     index_shape->data[0] = total_num_features;
497     index_shape->data[1] = 2;
498     context->ResizeTensor(context, indices, index_shape);
499 
500     TfLiteIntArray* output_shape = TfLiteIntArrayCreate(1);
501     output_shape->data[0] = total_num_features;
502     context->ResizeTensor(context, values, output_shape);
503 
504     SparseBuffer& buffer = sparse_buffers[d];
505 
506     // Update indices.
507     auto* indices_p = reinterpret_cast<int64_t*>(indices->data.raw);
508     if (!indices_p) {
509       return tf::errors::Internal("Indices tensor not allocated!");
510     }
511 
512     if (total_num_features > 0) {
513       int64_t* ix_p = indices_p;
514       size_t example_index = 0;
515       int idx0 = 0;
516       size_t delta = 0;
517       for (size_t example_end_index : buffer.example_end_indices) {
518         size_t feature_index = 0;
519         for (; delta < example_end_index; ++delta) {
520           // Column 0: example index
521           if (idx0 < total_num_features) {
522             *ix_p = example_index;
523             // Column 1: the feature index buffer example
524             *(ix_p + 1) = feature_index;
525             ix_p += 2;
526           }
527           ++feature_index;
528           ++idx0;
529         }
530         ++example_index;
531       }
532       CopySparseBufferToTensor(config.sparse[d].dtype, 0, &buffer, values);
533     }
534   }
535 
536   // Merge SparseBuffers from all minibatches for every config.dense having
537   // variable_length.
538   for (size_t d = 0; d < config.dense.size(); ++d) {
539     if (!config.dense[d].variable_length) {
540       continue;
541     }
542     size_t max_num_features = 0;
543     std::vector<size_t>& end_indices =
544         varlen_dense_buffers[d].example_end_indices;
545     max_num_features = std::max(max_num_features, end_indices[0]);
546     for (size_t i = 1; i < end_indices.size(); ++i) {
547       size_t example_size = end_indices[i] - end_indices[i - 1];
548       max_num_features = std::max(max_num_features, example_size);
549     }
550 
551     const size_t stride_size = config.dense[d].elements_per_stride;
552     const size_t max_num_elements = max_num_features / stride_size;
553     tf::TensorShape values_shape;
554     DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
555     const size_t batch_size = GetStringCount(serialized);
556     values_shape.AddDim(batch_size);
557     values_shape.AddDim(max_num_elements);
558     for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
559       values_shape.AddDim(config.dense[d].shape.dim_size(i));
560     }
561     TfLiteTensor* values = result->dense_values[d];
562     const size_t num_elements = GetTensorShape(values).FlatSize();
563 
564     // Nothing to write, exit early.
565     if (num_elements == 0) {
566       continue;
567     }
568 
569     const size_t num_elements_per_minibatch = num_elements / batch_size;
570     switch (config.dense[d].dtype) {
571       case tf::DT_INT64: {
572         FillAndCopyVarLen<tf::int64>(d, num_elements,
573                                      num_elements_per_minibatch, config,
574                                      varlen_dense_buffers, values);
575         break;
576       }
577       case tf::DT_FLOAT: {
578         FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
579                                  config, varlen_dense_buffers, values);
580         break;
581       }
582       default:
583         DCHECK(false) << "Encountered unexpected DataType "
584                       << config.dense[d].dtype
585                       << "in variable that should have been checked";
586     }
587   }
588 
589   // Merge tflite string buffers if necessary.
590   for (size_t d = 0; d < config.dense.size(); ++d) {
591     if (config.dense[d].variable_length) {
592       continue;
593     }
594     if (result->dense_values[d]->type == kTfLiteString) {
595       auto& in = result->dense_tensors[d];
596       auto vec = in.vec<tstring>();
597       const int batch_size = result->dense_values[d]->dims->data[0];
598       const int elements_per_stride = config.dense[d].elements_per_stride;
599       int total_size = 0;
600       std::vector<int32_t> offsets;
601       offsets.reserve(vec.size() + 1);
602       offsets.push_back(0);
603       int k = 0;
604       for (int i = 0; i < batch_size; ++i) {
605         for (int j = 0; j < elements_per_stride; ++j) {
606           if (i < count) {
607             total_size += vec(k++).size();
608             offsets.push_back(total_size);
609           } else {
610             offsets.push_back(total_size);
611           }
612         }
613       }
614       const int32_t num_strings = offsets.size() - 1;
615       const size_t required_bytes = sizeof(int32_t) * (num_strings + 2) +
616           total_size;
617       char* tensor_buffer =
618           reinterpret_cast<char*>(result->dense_values[d]->data.raw);
619       if (result->dense_values[d]->bytes < required_bytes) {
620         if (result->dense_values[d]->data.raw) {
621           free(result->dense_values[d]->data.raw);
622         }
623         tensor_buffer = reinterpret_cast<char*>(malloc(required_bytes));
624         result->dense_values[d]->data.raw = tensor_buffer;
625         result->dense_values[d]->bytes = required_bytes;
626       }
627       const int32_t start = sizeof(int32_t) * (num_strings + 2);
628       memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
629       for (size_t i = 0; i < offsets.size(); i++) {
630         int32_t offset_i = start + offsets[i];
631         memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset_i,
632                sizeof(int32_t));
633       }
634       tf::gtl::ArraySlice<tstring> slice(vec.data(), vec.size());
635       CopyToBuffer(slice, tensor_buffer + start, count, batch_size,
636                    elements_per_stride);
637     }
638   }
639   return Status::OK();
640 }
641 
642 }  // namespace
643 
// Input tensor indices of the ParseExample custom op node.
enum InputTensor {
  kExampleTensor = 0,     // serialized tf.Example strings
  kNamesTensor = 1,       // example names (for error messages)
  kSparseKeysTensor = 2,  // keys of sparse features
  kDenseKeysTensor = 3,   // keys of dense features
  kRaggedKeysTensor = 4,  // keys of ragged features (not handled in this file)
};
651 
// Per-node state allocated in Init and populated lazily during Prepare/Eval.
struct OpData {
  // Parsing configuration (dense/sparse feature descriptors).
  FastParseExampleConfig config;
  // Declared shapes of the dense outputs.
  std::vector<tf::TensorShape> dense_shapes;
  int dense_size = 0;
  int sparse_size = 0;
  // Fingerprint -> (feature index, Type) lookup used during parsing.
  std::unique_ptr<ConfigIndex> config_index;
  int config_index_size;
  SeededHasher hasher;
  // Output tensors of the most recent invocation.
  TfLiteResult got;
  // Length-indexed pre-filter over feature names; malloc-allocated
  // (presumably in Prepare/Eval, not visible in this chunk — confirm),
  // hence freed with free() below.
  bool* quick_filter = nullptr;
  int quick_filter_size;
  bool created = false;
  ~OpData() {
    if (quick_filter) {
      free(quick_filter);
    }
  }
};
670 
// TFLite op lifecycle: allocates the per-node OpData. Presumably released
// with `delete` in the op's Free callback (not visible in this chunk).
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData;
}
674 
675 template <typename T>
AsTensor(const std::vector<T> & val)676 tf::Tensor AsTensor(const std::vector<T>& val) {
677   tf::Tensor ret(tf::DataTypeToEnum<T>::value,
678                  {static_cast<tf::int64>(val.size())});
679   std::copy_n(val.begin(), val.size(), ret.flat<T>().data());
680   return ret;
681 }
682 
// Which flavor of the TF ParseExample op this node was converted from;
// selects which NodeDef attrs PrepareParseExample reads
// (V1: "Ndense"/"Nsparse", V2: "Tdense"/"num_sparse").
enum Version {
  V1,
  V2,
};
687 
TfLiteToTfShape(TfLiteIntArray * array)688 tf::TensorShape TfLiteToTfShape(TfLiteIntArray* array) {
689   tf::TensorShape shape;
690   for (int i = 0; i < array->size; i++) {
691     shape.AddDim(array->data[i]);
692   }
693   return shape;
694 }
695 
696 template <Version version>
PrepareParseExample(TfLiteContext * context,TfLiteNode * node)697 TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
698   OpData* data = reinterpret_cast<OpData*>(node->user_data);
699   TF_LITE_ENSURE(context, node->custom_initial_data);
700   data->config.dense.clear();
701   data->config.sparse.clear();
702   data->got.dense_values.clear();
703   const flexbuffers::Vector& v =
704       flexbuffers::GetRoot(
705           reinterpret_cast<const uint8_t*>(node->custom_initial_data),
706           node->custom_initial_data_size)
707           .AsVector();
708   if (v.size() == 2) {
709     tf::NodeDef nodedef;
710     TF_LITE_ENSURE_EQ(context, nodedef.ParseFromString(v[1].AsString().str()),
711                       true);
712     if (version == V1) {
713       data->dense_size = nodedef.attr().at("Ndense").i();
714       data->sparse_size = nodedef.attr().at("Nsparse").i();
715     } else if (version == V2) {
716       data->dense_size = nodedef.attr().at("Tdense").list().type_size();
717       data->sparse_size = nodedef.attr().at("num_sparse").i();
718     }
719     auto dense_shapes = nodedef.attr().at("dense_shapes").list();
720     for (int i = 0; i < dense_shapes.shape_size(); ++i) {
721       data->dense_shapes.push_back(dense_shapes.shape(i));
722     }
723   } else {
724     const flexbuffers::Map& m =
725         flexbuffers::GetRoot(
726             reinterpret_cast<const uint8_t*>(node->custom_initial_data),
727             node->custom_initial_data_size)
728             .AsMap();
729     const flexbuffers::TypedVector keys = m.Keys();
730     int num_sparse = 0;
731     int num_dense = 0;
732     for (int k = 0; k < keys.size(); ++k) {
733       const std::string key = keys[k].ToString();
734       const auto value = m[key];
735       if (key == "Nsparse" || key == "num_sparse") {
736         num_sparse = value.AsInt32();
737       }
738       if (key == "Ndense") {
739         num_dense = value.AsInt32();
740       }
741     }
742     data->sparse_size = num_sparse;
743     data->dense_size = num_dense;
744     if (version == V2) {
745       const TfLiteTensor* dense_key_tensor =
746           GetInput(context, node, kDenseKeysTensor);
747       data->dense_size = GetTensorShape(dense_key_tensor).FlatSize();
748     }
749   }
750 
751   data->config.dense.reserve(data->dense_size);
752   data->config.sparse.reserve(data->sparse_size);
753   data->dense_shapes.reserve(data->dense_size);
754   const auto* serialized = GetInput(context, node, 0);
755   const int batch_size =
756       serialized->dims->size > 0 ? serialized->dims->data[0] : 1;
757   const bool missing_shape_info = data->dense_shapes.empty();
758   for (int i = 0; i < data->dense_size; i++) {
759     TfLiteTensor* dense_key_tensor =
760         GetOutput(context, node, data->sparse_size * 3 + i);
761     TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_key_tensor->dims);
762     if (missing_shape_info) {
763       RuntimeShape runtime_shape = GetTensorShape(dense_key_tensor);
764       data->dense_shapes.push_back(TfLiteToTfShape(output_size));
765     }
766     output_size->data[0] = batch_size * output_size->data[0];
767     context->ResizeTensor(context, dense_key_tensor, output_size);
768   }
769 
770   size_t offset = 0;
771   for (int i = 0; i < data->sparse_size; i++) {
772     auto* parse_output = GetOutput(context, node, i + offset);
773     SetTensorToDynamic(parse_output);
774     TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(2);
775     sparse_size->data[0] = batch_size;
776     sparse_size->data[1] = 2;
777     context->ResizeTensor(context, parse_output, sparse_size);
778     data->got.sparse_indices.push_back(parse_output);
779   }
780   offset += data->sparse_size;
781   for (int i = 0; i < data->sparse_size; i++) {
782     auto* parse_output = GetOutput(context, node, i + offset);
783     SetTensorToDynamic(parse_output);
784     TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
785     sparse_size->data[0] = 0;
786     context->ResizeTensor(context, parse_output, sparse_size);
787     data->got.sparse_values.push_back(parse_output);
788   }
789   offset += data->sparse_size;
790   for (int i = 0; i < data->sparse_size; i++) {
791     TfLiteTensor* parse_output = GetOutput(context, node, i + offset);
792     SetTensorToDynamic(parse_output);
793     TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
794     sparse_size->data[0] = 2;
795     context->ResizeTensor(context, parse_output, sparse_size);
796     auto* shapes_shape_t = reinterpret_cast<int64_t*>(parse_output->data.i64);
797     shapes_shape_t[0] = batch_size;
798     shapes_shape_t[1] = 1;
799     data->got.sparse_shapes.push_back(parse_output);
800   }
801   data->created = false;
802   return kTfLiteOk;
803 }
804 
805 template <Version version>
EvalParseExample(TfLiteContext * context,TfLiteNode * node)806 TfLiteStatus EvalParseExample(TfLiteContext* context, TfLiteNode* node) {
807   OpData* data = reinterpret_cast<OpData*>(node->user_data);
808   if (!data->created) {
809     for (int i = 0; i < data->sparse_size; i++) {
810       int input_index =
811           version == V1 ? kSparseKeysTensor + i : kSparseKeysTensor;
812       int string_index = version == V1 ? 0 : i;
813       const TfLiteTensor* sparse_key_tensor =
814           GetInput(context, node, input_index);
815       const auto key = GetString(sparse_key_tensor, string_index);
816       const auto* sparse_output =
817           GetOutput(context, node, i + data->sparse_size);
818       std::string k(key.str, key.len);
819       switch (sparse_output->type) {
820         case kTfLiteInt64:
821           data->config.sparse.emplace_back(
822               k, tf::DataTypeToEnum<tf::int64>::value);
823           break;
824         case kTfLiteFloat32:
825           data->config.sparse.emplace_back(k, tf::DataTypeToEnum<float>::value);
826           break;
827         case kTfLiteString:
828           data->config.sparse.emplace_back(k,
829                                            tf::DataTypeToEnum<tstring>::value);
830           break;
831         default:
832           return kTfLiteError;
833       }
834     }
835 
836     const auto& dense_shapes = data->dense_shapes;
837     for (int i = 0; i < data->dense_size; i++) {
838       const int input_index = version == V1
839                                   ? kSparseKeysTensor + data->sparse_size + i
840                                   : kDenseKeysTensor;
841       const int dense_defaults_index =
842           version == V1
843               ? kSparseKeysTensor + data->sparse_size + data->dense_size + i
844               : kRaggedKeysTensor + i + 1;
845       int string_index = version == V1 ? 0 : i;
846       const TfLiteTensor* dense_key_tensor =
847           GetInput(context, node, input_index);
848       const auto* dense_output =
849           GetOutput(context, node, i + data->sparse_size * 3);
850       const auto* dense_defaults =
851           GetInput(context, node, dense_defaults_index);
852       const auto key = GetString(dense_key_tensor, string_index);
853       std::string k(key.str, key.len);
854       const int elements_per_stride =
855           dense_shapes[i].dims() ? dense_shapes[i].num_elements() : 1;
856       switch (dense_output->type) {
857         case kTfLiteInt64:
858           data->config.dense.emplace_back(
859               k, tf::DataTypeToEnum<tf::int64>::value, dense_shapes[i],
860               AsTensor<tf::int64>(std::vector<tf::int64>(
861                   dense_defaults->data.i64,
862                   dense_defaults->data.i64 + elements_per_stride)),
863               false, elements_per_stride);
864           break;
865         case kTfLiteFloat32:
866           data->config.dense.emplace_back(
867               k, tf::DataTypeToEnum<float>::value, dense_shapes[i],
868               AsTensor<float>(std::vector<float>(
869                   dense_defaults->data.f,
870                   dense_defaults->data.f + elements_per_stride)),
871               false, elements_per_stride);
872           break;
873         case kTfLiteString: {
874           const int num_strings = GetStringCount(dense_defaults);
875           std::vector<tstring> values;
876           for (int i = 0; i < num_strings; ++i) {
877             auto ref = GetString(dense_defaults, i);
878             values.emplace_back(ref.str, ref.len);
879           }
880           data->config.dense.emplace_back(
881               k, tf::DataTypeToEnum<tstring>::value, dense_shapes[i],
882               AsTensor<tstring>(values), false, elements_per_stride);
883           break;
884         }
885         default:
886           return kTfLiteError;
887       }
888     }
889 
890     int offset = 3 * data->sparse_size;
891     for (int i = 0; i < data->dense_size; i++) {
892       auto* parse_output = GetOutput(context, node, i + offset);
893       data->got.dense_values.push_back(parse_output);
894       if (parse_output->type == kTfLiteString) {
895         tf::TensorShape shape;
896         if (parse_output->dims->size == 1) {
897           shape.AddDim(parse_output->dims->data[0]);
898         } else {
899           shape.AddDim(GetTensorShape(parse_output).FlatSize());
900         }
901         data->got.dense_tensors[i] =
902             tf::Tensor(tf::DataTypeToEnum<tstring>::value, shape);
903       }
904     }
905 
906     size_t config_size = data->config.dense.size();
907     config_size += data->config.sparse.size();
908     data->config_index_size = config_size;
909     auto config_index = std::make_unique<ConfigIndex>(config_size);
910     bool ok = true;
911     int max_length = 0;
912     for (size_t d = 0; d < data->config.dense.size(); ++d) {
913       auto s = data->config.dense[d].feature_name;
914       max_length = s.length() > max_length ? s.length() : max_length;
915     }
916     for (size_t d = 0; d < data->config.sparse.size(); ++d) {
917       auto s = data->config.sparse[d].feature_name;
918       max_length = s.length() > max_length ? s.length() : max_length;
919     }
920     if (data->quick_filter) {
921       free(data->quick_filter);
922     }
923     data->quick_filter =
924         static_cast<bool*>(malloc(++max_length * sizeof(bool)));
925     memset(data->quick_filter, 0, max_length * sizeof(bool));
926     data->quick_filter_size = max_length;
927     for (size_t d = 0; d < data->config.dense.size(); ++d) {
928       const auto& s = data->config.dense[d].feature_name;
929       data->quick_filter[s.length()] = true;
930     }
931     for (size_t d = 0; d < data->config.sparse.size(); ++d) {
932       const auto& s = data->config.sparse[d].feature_name;
933       data->quick_filter[s.length()] = true;
934     }
935 
936     for (int i = 0; i < 1000; ++i) {
937       for (size_t d = 0; d < data->config.dense.size(); ++d) {
938         ok &= config_index->InsertUnique(
939             data->hasher(data->config.dense[d].feature_name), {d, Type::Dense});
940       }
941       for (size_t d = 0; d < data->config.sparse.size(); ++d) {
942         ok &= config_index->InsertUnique(
943             data->hasher(data->config.sparse[d].feature_name),
944             {d, Type::Sparse});
945       }
946       if (ok) {
947         break;
948       }
949       data->hasher.seed++;
950       config_index->Clear(config_size);
951       ok = true;
952     }
953     if (!ok) {
954       return kTfLiteError;
955     }
956     data->config_index = std::move(config_index);
957     data->created = true;
958   }
959 
960   const TfLiteTensor* serialized = GetInput(context, node, kExampleTensor);
961 
962   std::map<absl::string_view, int> stats;
963   const auto status = FastParseExampleLite(
964       data->config, serialized, {}, data->quick_filter, data->quick_filter_size,
965       data->config_index, data->config_index_size, &data->hasher, &data->got,
966       stats, context);
967   if (status != tf::Status::OK()) {
968     TF_LITE_KERNEL_LOG(context, status.ToString().c_str());
969     return kTfLiteError;
970   }
971   return kTfLiteOk;
972 }
973 
Free(TfLiteContext * context,void * buffer)974 void Free(TfLiteContext* context, void* buffer) {
975   auto* obj = reinterpret_cast<OpData*>(buffer);
976   delete obj;
977 }
978 
979 }  // namespace parse_example
980 
Register_PARSE_EXAMPLE()981 TfLiteRegistration* Register_PARSE_EXAMPLE() {
982   static TfLiteRegistration r = {
983       parse_example::Init, parse_example::Free,
984       parse_example::PrepareParseExample<parse_example::V1>,
985       parse_example::EvalParseExample<parse_example::V1>};
986   return &r;
987 }
988 
Register_PARSE_EXAMPLE_V2()989 TfLiteRegistration* Register_PARSE_EXAMPLE_V2() {
990   static TfLiteRegistration r = {
991       parse_example::Init, parse_example::Free,
992       parse_example::PrepareParseExample<parse_example::V2>,
993       parse_example::EvalParseExample<parse_example::V2>};
994   return &r;
995 }
996 
AddParseExampleOp(::tflite::MutableOpResolver * resolver)997 extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver) {
998   resolver->AddCustom("ParseExample", Register_PARSE_EXAMPLE());
999   resolver->AddCustom("ParseExampleV2", Register_PARSE_EXAMPLE_V2());
1000 }
1001 
1002 }  // namespace custom
1003 }  // namespace ops
1004 }  // namespace tflite
1005