/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/parse_example/parse_example.h"

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/example_proto_fast_parsing.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace parse_example {
namespace {

namespace tf = ::tensorflow;
using tf::Status;
using tf::StringPiece;
using tf::tstring;
using tf::example::CopyOrMoveBlock;
using tf::example::FastParseExampleConfig;
using tf::example::GetListFromBuffer;
using tf::example::LimitedArraySlice;
using tf::example::ParseExample;
using tf::example::SeededHasher;
using tf::example::SmallVector;
using tf::example::SparseBuffer;
using tf::example::Type;
using tf::example::parsed::Example;

using ConfigIndex = tf::PresizedCuckooMap<std::pair<int32_t, Type>>;

struct TfLiteResult {
  std::vector<TfLiteTensor*> dense_values;
  std::vector<TfLiteTensor*> sparse_values;
  std::vector<TfLiteTensor*> sparse_indices;
  std::vector<TfLiteTensor*> sparse_shapes;
  std::map<int, tf::Tensor> dense_tensors;
};

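// Fills a variable-length dense output: the whole tensor is first filled
// with the feature's default value, then each example's parsed values are
// copied into its minibatch slot.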
template <typename T>
void FillAndCopyVarLen(const int d, const size_t num_elements,
                       const size_t num_elements_per_minibatch,
                       const FastParseExampleConfig& config,
                       std::vector<SparseBuffer>& varlen_dense_buffers,
                       TfLiteTensor* values) {
  const tf::Tensor& default_value = config.dense[d].default_value;

  // Copy-fill the tensors (creating the zero/fill-padding).
  std::fill(reinterpret_cast<T*>(values->data.raw),
            reinterpret_cast<T*>(values->data.raw) + num_elements,
            default_value.flat<T>()(0));

  auto data = reinterpret_cast<T*>(values->data.raw);

  const SparseBuffer& buffer = varlen_dense_buffers[d];
  // Number of examples being stored in this buffer.
  const auto& end_indices = buffer.example_end_indices;
  const size_t examples_in_buffer = end_indices.size();

  const auto& list = GetListFromBuffer<T>(buffer);
  auto list_ptr = list.begin();

  size_t elements_tally = 0;
  // Iterate through all the examples stored in this buffer.
  for (size_t j = 0; j < examples_in_buffer; ++j) {
    // Number of elements stored for this example.
    const size_t num_elems = end_indices[j] - elements_tally;
    CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
    // Move forward this many elements in the varlen buffer.
    list_ptr += num_elems;
    // Move forward to the next minibatch entry in the values output.
    data += num_elements_per_minibatch;
    elements_tally = end_indices[j];
  }
  DCHECK(elements_tally == list.size());
}

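// Parses one serialized tf.Example proto given as a TFLite string ref,
// aliasing the input buffer rather than copying it.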
bool ParseExample(StringRef serialized, Example* example) {
  DCHECK(example != nullptr);
  tf::protobuf::io::CodedInputStream stream(
      reinterpret_cast<const uint8*>(serialized.str), serialized.len);
  tensorflow::example::EnableAliasing(&stream);
  return ParseExample(&stream, example);
}

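// Parses a single serialized Example into the per-feature output buffers.
// Feature names are matched against the config with a length-based quick
// filter followed by a lookup in the presized cuckoo map; dense features
// missing from the example are backfilled with their defaults.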
Status FastParseSerializedExample(
    StringRef serialized_example, const tstring& example_name,
    const size_t example_index, const FastParseExampleConfig& config,
    bool* quick_filter, int quick_filter_size,
    const std::unique_ptr<ConfigIndex>& config_index, int config_index_size,
    SeededHasher* hasher, std::vector<TfLiteTensor*>* output_dense,
    std::vector<SparseBuffer>* output_varlen_dense,
    std::vector<SparseBuffer>* output_sparse,
    std::map<absl::string_view, int>& stats, TfLiteResult* result) {
  DCHECK(output_dense != nullptr);
  tensorflow::example::parsed::Example parsed_example;
  if (!ParseExample(serialized_example, &parsed_example)) {
    return tf::errors::Internal("Failed to parse example");
  }
  std::vector<tf::int64> dense_feature_last_example(config.dense.size(), -1);
  std::vector<tf::int64> sparse_feature_last_example(config.sparse.size(), -1);
  // Handle features present in the example.
  const size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This mirrors the semantics of standard protobuf map parsing: the last
    // entry with a given key overwrites all previous ones, so iterate in
    // reverse and skip keys that were already seen.
    tensorflow::example::parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    const StringPiece feature_name = name_and_feature.first;
    tensorflow::example::parsed::Feature& feature = name_and_feature.second;
    if (feature_name.length() >= quick_filter_size ||
        !quick_filter[feature_name.length()]) {
      continue;
    }
    const uint64_t h = (*hasher)(feature_name);
    std::pair<int32_t, Type> d_and_type;
    if (!config_index->Find(h, &d_and_type)) {
      continue;
    }
    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;

    auto example_error = [&](StringPiece suffix) {
      return tf::errors::Internal("Name: ", example_name,
                                  ", Key: ", feature_name,
                                  ", Index: ", example_index, ". ", suffix);
    };

    auto parse_error = [&] {
      return example_error("Can't parse serialized Example.");
    };

    tf::DataType example_dtype;
    if (feature.ParseDataType(&example_dtype) != Status::OK()) {
      return parse_error();
    }
    if (is_dense) {
      if (example_dtype == tf::DT_INVALID) continue;

      dense_feature_last_example[d] = example_index;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(absl::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }
      if (!config.dense[d].variable_length) {
        TfLiteTensor* out = (*output_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;
        const std::size_t offset = example_index * num_elements;

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(absl::StrCat(
              "Number of ", type_str,
              " values != expected. "
              "Values size: ",
              size,
              " but output shape: ", config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case tf::DT_INT64: {
            auto out_p = reinterpret_cast<tf::int64*>(out->data.raw) + offset;
            LimitedArraySlice<tf::int64> slice(out_p, num_elements);
            if (!feature.ParseInt64List(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "int64");
            }
            break;
          }
          case tf::DT_FLOAT: {
            auto out_p = reinterpret_cast<float*>(out->data.raw) + offset;
            LimitedArraySlice<float> slice(out_p, num_elements);
            if (!feature.ParseFloatList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "float");
            }
            break;
          }
          case tf::DT_STRING: {
            auto& out_tensor = result->dense_tensors[d];
            auto out_p = out_tensor.flat<tstring>().data() + offset;
            LimitedArraySlice<tstring> slice(out_p, num_elements);
            if (!feature.ParseBytesList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "bytes");
            }
            break;
          }
          default:
            return tf::errors::Internal("Unrecognized dense type: ",
                                        config.dense[d].dtype);
        }
      } else {  // Dense, variable-length feature.
        SparseBuffer& out = (*output_varlen_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;

        if (example_dtype != tf::DT_INVALID &&
            example_dtype != config.dense[d].dtype) {
          return example_error(absl::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(config.dense[d].dtype)));
        }

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(
              absl::StrCat("Number of ", type_str,
                           " values is not a multiple of stride length. Saw ",
                           size, " values but output shape is: ",
                           config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case tf::DT_INT64: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseInt64List(&out.int64_list)) {
                return parse_error();
              }
              if (out.int64_list.size() % num_elements != 0) {
                return shape_error(out.int64_list.size(), "int64");
              }
            }
            out.example_end_indices.push_back(out.int64_list.size());
            break;
          }
          case tf::DT_FLOAT: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseFloatList(&out.float_list)) {
                return parse_error();
              }
              if (out.float_list.size() % num_elements != 0) {
                return shape_error(out.float_list.size(), "float");
              }
            }
            out.example_end_indices.push_back(out.float_list.size());
            break;
          }
          case tf::DT_STRING: {
            if (example_dtype != tf::DT_INVALID) {
              if (!feature.ParseBytesList(&out.bytes_list)) {
                return parse_error();
              }
              if (out.bytes_list.size() % num_elements != 0) {
                return shape_error(out.bytes_list.size(), "bytes");
              }
            }
            out.example_end_indices.push_back(out.bytes_list.size());
            break;
          }
          default:
            return tf::errors::Internal("Unexpected dense data type: ",
                                        config.dense[d].dtype);
        }
      }
    } else {
      // Sparse (or ragged) feature.
      auto& last_example = sparse_feature_last_example;
      if (last_example[d] == example_index) {
        continue;
      }
      last_example[d] = example_index;
      SparseBuffer& out = (*output_sparse)[d];
      tf::DataType feature_dtype = config.sparse[d].dtype;
      if (example_dtype != tf::DT_INVALID && example_dtype != feature_dtype) {
        return tf::errors::Internal("Data types don't match: ", example_dtype,
                                    " != ", feature_dtype);
      }
      switch (feature_dtype) {
        case tf::DT_INT64: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseInt64List(&out.int64_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.int64_list.size());
          break;
        }
        case tf::DT_FLOAT: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseFloatList(&out.float_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.float_list.size());
          break;
        }
        case tf::DT_STRING: {
          if (example_dtype != tf::DT_INVALID) {
            if (!feature.ParseBytesList(&out.bytes_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.bytes_list.size());
          break;
        }
        default:
          return tf::errors::Internal("Unexpected sparse data type: ",
                                      feature_dtype);
      }
    }
  }
  // Handle missing dense features for fixed strides.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    if (config.dense[d].default_value.NumElements() == 0) {
      return tf::errors::Internal(
          "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
          " (data type: ", DataTypeString(config.dense[d].dtype), ")",
          " is required but could not be found.");
    }
    const tf::Tensor& in = config.dense[d].default_value;
    TfLiteTensor* out = result->dense_values[d];
    const std::size_t num_elements = in.shape().num_elements();
    const std::size_t offset = example_index * num_elements;
    switch (config.dense[d].dtype) {
      case tf::DT_INT64: {
        std::copy_n(in.flat<tf::int64>().data(), num_elements,
                    out->data.i64 + offset);
        break;
      }
      case tf::DT_FLOAT: {
        std::copy_n(in.flat<float>().data(), num_elements,
                    out->data.f + offset);
        break;
      }
      case tf::DT_STRING: {
        auto& out_tensor = result->dense_tensors[d];
        std::copy_n(in.flat<tstring>().data(), num_elements,
                    out_tensor.flat<tstring>().data() + offset);
        break;
      }
      default:
        return tf::errors::Internal("Unexpected dense data type: ",
                                    config.dense[d].dtype);
    }
  }
  // Handle missing variable-length dense features: record an empty extent.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_varlen_dense)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Handle missing sparse features: record an empty extent.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (sparse_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_sparse)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  return Status::OK();
}

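// Accumulates the total number of features stored in `sparse_buffer` and
// the largest per-example feature count, derived from the cumulative
// example end indices.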
void CountSparseFeatures(const SparseBuffer& sparse_buffer,
                         size_t* total_num_features,
                         size_t* max_num_features) {
  const std::vector<size_t>& end_indices = sparse_buffer.example_end_indices;
  *total_num_features += end_indices.back();
  *max_num_features = std::max(*max_num_features, end_indices[0]);
  for (size_t i = 1; i < end_indices.size(); ++i) {
    size_t example_size = end_indices[i] - end_indices[i - 1];
    *max_num_features = std::max(*max_num_features, example_size);
  }
}

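// Copies a parsed SparseBuffer into the destination TfLiteTensor. Strings
// go through DynamicBuffer, which handles TFLite's packed string format.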
void CopySparseBufferToTensor(tf::DataType dtype, size_t offset,
                              SparseBuffer* src, TfLiteTensor* dst) {
  switch (dtype) {
    case tf::DT_INT64: {
      std::copy(src->int64_list.begin(), src->int64_list.end(),
                reinterpret_cast<int64_t*>(dst->data.raw) + offset);
      break;
    }
    case tf::DT_FLOAT: {
      std::copy(src->float_list.begin(), src->float_list.end(),
                reinterpret_cast<float*>(dst->data.raw) + offset);
      break;
    }
    case tf::DT_STRING: {
      DynamicBuffer buffer;
      for (auto* begin = src->bytes_list.begin();
           begin != src->bytes_list.end(); begin++) {
        buffer.AddString(begin->c_str(), begin->size());
      }
      buffer.WriteToTensor(dst, nullptr);
      break;
    }
    default:
      DCHECK(false) << "Encountered unexpected DataType "
                    << DataTypeString(dtype)
                    << " in variable that should have been checked.";
  }
}

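// Concatenates the string payloads in `vec` into `tensor_buffer`; the
// second loop pads the remaining batch entries when fewer than `batch_size`
// examples were parsed.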
inline void CopyToBuffer(tf::gtl::ArraySlice<tstring> vec, char* tensor_buffer,
                         int num_examples, int batch_size,
                         int elements_per_stride) {
  int i = 0, k = 0;
  int start = 0;
  for (; i < num_examples; ++i) {
    for (int j = 0; j < elements_per_stride; ++j) {
      memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
      start += vec[k].size();
      k++;
    }
  }
  // Pad the remainder; this only runs when the number of parsed examples is
  // smaller than the requested batch size.
  for (; i < batch_size; ++i) {
    for (int j = 0; j < elements_per_stride; ++j) {
      memcpy(tensor_buffer + start, vec[k].c_str(), vec[k].size());
      start += vec[k].size();
      k++;
    }
  }
}

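// Parses every Example in the `serialized` batch and materializes the
// fixed dense, variable-length dense, and sparse outputs into `result`.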
Status FastParseExampleLite(
    const FastParseExampleConfig& config, const TfLiteTensor* serialized,
    tf::gtl::ArraySlice<tstring> example_names, bool* quick_filter,
    int quick_filter_size, const std::unique_ptr<ConfigIndex>& config_index,
    int config_index_size, SeededHasher* hasher, TfLiteResult* result,
    std::map<absl::string_view, int>& stats, TfLiteContext* context) {
  if (result == nullptr) {
    return tf::errors::Internal("Result is null");
  }
  const int count = GetStringCount(serialized);
  std::vector<tf::Tensor> fixed_dense_values(config.dense.size());
  std::vector<SparseBuffer> sparse_buffers(config.sparse.size());
  std::vector<SparseBuffer> varlen_dense_buffers(config.dense.size());
  Status status_of_minibatch;
  for (size_t e = 0; e < count; ++e) {
    status_of_minibatch = FastParseSerializedExample(
        GetString(serialized, e),
        (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
        quick_filter, quick_filter_size, config_index, config_index_size,
        hasher, &result->dense_values, &varlen_dense_buffers, &sparse_buffers,
        stats, result);
    if (!status_of_minibatch.ok()) break;
  }
  if (!status_of_minibatch.ok()) {
    return status_of_minibatch;
  }
  // Merge SparseBuffers from all minibatches for every config.sparse.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(sparse_buffers[d], &total_num_features,
                        &max_num_features);
    tf::TensorShape indices_shape;
    TfLiteTensor* indices = result->sparse_indices[d];
    TfLiteTensor* values = result->sparse_values[d];

    TfLiteTensor* dense_shape = result->sparse_shapes[d];
    auto* dense_shape_ptr = reinterpret_cast<int64_t*>(dense_shape->data.raw);
    dense_shape_ptr[1] = max_num_features;

    TfLiteIntArray* index_shape = TfLiteIntArrayCreate(2);
    index_shape->data[0] = total_num_features;
    index_shape->data[1] = 2;
    context->ResizeTensor(context, indices, index_shape);

    TfLiteIntArray* output_shape = TfLiteIntArrayCreate(1);
    output_shape->data[0] = total_num_features;
    context->ResizeTensor(context, values, output_shape);

    SparseBuffer& buffer = sparse_buffers[d];

    // Update indices.
    auto* indices_p = reinterpret_cast<int64_t*>(indices->data.raw);
    if (!indices_p) {
      return tf::errors::Internal("Indices tensor not allocated!");
    }

    if (total_num_features > 0) {
      int64_t* ix_p = indices_p;
      size_t example_index = 0;
      int idx0 = 0;
      size_t delta = 0;
      for (size_t example_end_index : buffer.example_end_indices) {
        size_t feature_index = 0;
        for (; delta < example_end_index; ++delta) {
          if (idx0 < total_num_features) {
            // Column 0: example index.
            *ix_p = example_index;
            // Column 1: feature index within the example.
            *(ix_p + 1) = feature_index;
            ix_p += 2;
          }
          ++feature_index;
          ++idx0;
        }
        ++example_index;
      }
      CopySparseBufferToTensor(config.sparse[d].dtype, 0, &buffer, values);
    }
  }

  // Merge SparseBuffers from all minibatches for every config.dense having
  // variable_length.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) {
      continue;
    }
    size_t max_num_features = 0;
    std::vector<size_t>& end_indices =
        varlen_dense_buffers[d].example_end_indices;
    max_num_features = std::max(max_num_features, end_indices[0]);
    for (size_t i = 1; i < end_indices.size(); ++i) {
      size_t example_size = end_indices[i] - end_indices[i - 1];
      max_num_features = std::max(max_num_features, example_size);
    }

    const size_t stride_size = config.dense[d].elements_per_stride;
    const size_t max_num_elements = max_num_features / stride_size;
    tf::TensorShape values_shape;
    DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
    const size_t batch_size = GetStringCount(serialized);
    values_shape.AddDim(batch_size);
    values_shape.AddDim(max_num_elements);
    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
      values_shape.AddDim(config.dense[d].shape.dim_size(i));
    }
    TfLiteTensor* values = result->dense_values[d];
    const size_t num_elements = GetTensorShape(values).FlatSize();

    // Nothing to write; exit early.
    if (num_elements == 0) {
      continue;
    }

    const size_t num_elements_per_minibatch = num_elements / batch_size;
    switch (config.dense[d].dtype) {
      case tf::DT_INT64: {
        FillAndCopyVarLen<tf::int64>(d, num_elements,
                                     num_elements_per_minibatch, config,
                                     varlen_dense_buffers, values);
        break;
      }
      case tf::DT_FLOAT: {
        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, values);
        break;
      }
      default:
        DCHECK(false) << "Encountered unexpected DataType "
                      << config.dense[d].dtype
                      << " in variable that should have been checked.";
    }
  }

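  // TFLite packs a string tensor as an int32 string count, followed by
  // (count + 1) int32 byte offsets, followed by the concatenated string
  // payloads; hence the (num_strings + 2) int32 header words sized below.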
  // Merge tflite string buffers if necessary.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) {
      continue;
    }
    if (result->dense_values[d]->type == kTfLiteString) {
      auto& in = result->dense_tensors[d];
      auto vec = in.vec<tstring>();
      const int batch_size = result->dense_values[d]->dims->data[0];
      const int elements_per_stride = config.dense[d].elements_per_stride;
      int total_size = 0;
      std::vector<int32_t> offsets;
      offsets.reserve(vec.size() + 1);
      offsets.push_back(0);
      int k = 0;
      for (int i = 0; i < batch_size; ++i) {
        for (int j = 0; j < elements_per_stride; ++j) {
          if (i < count) {
            total_size += vec(k++).size();
            offsets.push_back(total_size);
          } else {
            offsets.push_back(total_size);
          }
        }
      }
      const int32_t num_strings = offsets.size() - 1;
      const size_t required_bytes =
          sizeof(int32_t) * (num_strings + 2) + total_size;
      char* tensor_buffer =
          reinterpret_cast<char*>(result->dense_values[d]->data.raw);
      if (result->dense_values[d]->bytes < required_bytes) {
        if (result->dense_values[d]->data.raw) {
          free(result->dense_values[d]->data.raw);
        }
        tensor_buffer = reinterpret_cast<char*>(malloc(required_bytes));
        result->dense_values[d]->data.raw = tensor_buffer;
        result->dense_values[d]->bytes = required_bytes;
      }
      const int32_t start = sizeof(int32_t) * (num_strings + 2);
      memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
      for (size_t i = 0; i < offsets.size(); i++) {
        int32_t offset_i = start + offsets[i];
        memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset_i,
               sizeof(int32_t));
      }
      tf::gtl::ArraySlice<tstring> slice(vec.data(), vec.size());
      CopyToBuffer(slice, tensor_buffer + start, count, batch_size,
                   elements_per_stride);
    }
  }
  return Status::OK();
}


}  // namespace

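// Indices of the op's input tensors. The layout matches ParseExampleV2;
// the V1 kernel derives its per-feature input positions starting from
// kSparseKeysTensor.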
enum InputTensor {
  kExampleTensor = 0,
  kNamesTensor = 1,
  kSparseKeysTensor = 2,
  kDenseKeysTensor = 3,
  kRaggedKeysTensor = 4,
};

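// Per-node state: the parse configuration, the seeded feature-name hash
// index, the length-based quick filter, and cached output tensor pointers.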
struct OpData {
  FastParseExampleConfig config;
  std::vector<tf::TensorShape> dense_shapes;
  int dense_size = 0;
  int sparse_size = 0;
  std::unique_ptr<ConfigIndex> config_index;
  int config_index_size;
  SeededHasher hasher;
  TfLiteResult got;
  bool* quick_filter = nullptr;
  int quick_filter_size;
  bool created = false;
  ~OpData() {
    if (quick_filter) {
      free(quick_filter);
    }
  }
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData;
}

template <typename T>
tf::Tensor AsTensor(const std::vector<T>& val) {
  tf::Tensor ret(tf::DataTypeToEnum<T>::value,
                 {static_cast<tf::int64>(val.size())});
  std::copy_n(val.begin(), val.size(), ret.flat<T>().data());
  return ret;
}

enum Version {
  V1,
  V2,
};

tf::TensorShape TfLiteToTfShape(TfLiteIntArray* array) {
  tf::TensorShape shape;
  for (int i = 0; i < array->size; i++) {
    shape.AddDim(array->data[i]);
  }
  return shape;
}

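// Recovers the feature counts and dense shapes from the node's flexbuffer
// custom data (either a serialized NodeDef or a flexbuffer map) and resizes
// the output tensors for the incoming batch size.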
template <Version version>
TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE(context, node->custom_initial_data);
  data->config.dense.clear();
  data->config.sparse.clear();
  data->got.dense_values.clear();
  const flexbuffers::Vector& v =
      flexbuffers::GetRoot(
          reinterpret_cast<const uint8_t*>(node->custom_initial_data),
          node->custom_initial_data_size)
          .AsVector();
  if (v.size() == 2) {
    tf::NodeDef nodedef;
    TF_LITE_ENSURE_EQ(context, nodedef.ParseFromString(v[1].AsString().str()),
                      true);
    if (version == V1) {
      data->dense_size = nodedef.attr().at("Ndense").i();
      data->sparse_size = nodedef.attr().at("Nsparse").i();
    } else if (version == V2) {
      data->dense_size = nodedef.attr().at("Tdense").list().type_size();
      data->sparse_size = nodedef.attr().at("num_sparse").i();
    }
    auto dense_shapes = nodedef.attr().at("dense_shapes").list();
    for (int i = 0; i < dense_shapes.shape_size(); ++i) {
      data->dense_shapes.push_back(dense_shapes.shape(i));
    }
  } else {
    const flexbuffers::Map& m =
        flexbuffers::GetRoot(
            reinterpret_cast<const uint8_t*>(node->custom_initial_data),
            node->custom_initial_data_size)
            .AsMap();
    const flexbuffers::TypedVector keys = m.Keys();
    int num_sparse = 0;
    int num_dense = 0;
    for (int k = 0; k < keys.size(); ++k) {
      const std::string key = keys[k].ToString();
      const auto value = m[key];
      if (key == "Nsparse" || key == "num_sparse") {
        num_sparse = value.AsInt32();
      }
      if (key == "Ndense") {
        num_dense = value.AsInt32();
      }
    }
    data->sparse_size = num_sparse;
    data->dense_size = num_dense;
    if (version == V2) {
      const TfLiteTensor* dense_key_tensor =
          GetInput(context, node, kDenseKeysTensor);
      data->dense_size = GetTensorShape(dense_key_tensor).FlatSize();
    }
  }

  data->config.dense.reserve(data->dense_size);
  data->config.sparse.reserve(data->sparse_size);
  data->dense_shapes.reserve(data->dense_size);
  const auto* serialized = GetInput(context, node, 0);
  const int batch_size =
      serialized->dims->size > 0 ? serialized->dims->data[0] : 1;
  const bool missing_shape_info = data->dense_shapes.empty();
  for (int i = 0; i < data->dense_size; i++) {
    TfLiteTensor* dense_key_tensor =
        GetOutput(context, node, data->sparse_size * 3 + i);
    TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_key_tensor->dims);
    if (missing_shape_info) {
      data->dense_shapes.push_back(TfLiteToTfShape(output_size));
    }
    output_size->data[0] = batch_size * output_size->data[0];
    context->ResizeTensor(context, dense_key_tensor, output_size);
  }

  size_t offset = 0;
  for (int i = 0; i < data->sparse_size; i++) {
    auto* parse_output = GetOutput(context, node, i + offset);
    SetTensorToDynamic(parse_output);
    TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(2);
    sparse_size->data[0] = batch_size;
    sparse_size->data[1] = 2;
    context->ResizeTensor(context, parse_output, sparse_size);
    data->got.sparse_indices.push_back(parse_output);
  }
  offset += data->sparse_size;
  for (int i = 0; i < data->sparse_size; i++) {
    auto* parse_output = GetOutput(context, node, i + offset);
    SetTensorToDynamic(parse_output);
    TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
    sparse_size->data[0] = 0;
    context->ResizeTensor(context, parse_output, sparse_size);
    data->got.sparse_values.push_back(parse_output);
  }
  offset += data->sparse_size;
  for (int i = 0; i < data->sparse_size; i++) {
    TfLiteTensor* parse_output = GetOutput(context, node, i + offset);
    SetTensorToDynamic(parse_output);
    TfLiteIntArray* sparse_size = TfLiteIntArrayCreate(1);
    sparse_size->data[0] = 2;
    context->ResizeTensor(context, parse_output, sparse_size);
    auto* shapes_shape_t = reinterpret_cast<int64_t*>(parse_output->data.i64);
    shapes_shape_t[0] = batch_size;
    shapes_shape_t[1] = 1;
    data->got.sparse_shapes.push_back(parse_output);
  }
  data->created = false;
  return kTfLiteOk;
}

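// On the first invocation, builds the parse config from the key and default
// input tensors and constructs the quick filter and hashed config index;
// each invocation then parses the serialized batch with FastParseExampleLite.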
template <Version version>
TfLiteStatus EvalParseExample(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  if (!data->created) {
    for (int i = 0; i < data->sparse_size; i++) {
      int input_index =
          version == V1 ? kSparseKeysTensor + i : kSparseKeysTensor;
      int string_index = version == V1 ? 0 : i;
      const TfLiteTensor* sparse_key_tensor =
          GetInput(context, node, input_index);
      const auto key = GetString(sparse_key_tensor, string_index);
      const auto* sparse_output =
          GetOutput(context, node, i + data->sparse_size);
      std::string k(key.str, key.len);
      switch (sparse_output->type) {
        case kTfLiteInt64:
          data->config.sparse.emplace_back(
              k, tf::DataTypeToEnum<tf::int64>::value);
          break;
        case kTfLiteFloat32:
          data->config.sparse.emplace_back(k, tf::DataTypeToEnum<float>::value);
          break;
        case kTfLiteString:
          data->config.sparse.emplace_back(k,
                                           tf::DataTypeToEnum<tstring>::value);
          break;
        default:
          return kTfLiteError;
      }
    }

    const auto& dense_shapes = data->dense_shapes;
    for (int i = 0; i < data->dense_size; i++) {
      const int input_index = version == V1
                                  ? kSparseKeysTensor + data->sparse_size + i
                                  : kDenseKeysTensor;
      const int dense_defaults_index =
          version == V1
              ? kSparseKeysTensor + data->sparse_size + data->dense_size + i
              : kRaggedKeysTensor + i + 1;
      int string_index = version == V1 ? 0 : i;
      const TfLiteTensor* dense_key_tensor =
          GetInput(context, node, input_index);
      const auto* dense_output =
          GetOutput(context, node, i + data->sparse_size * 3);
      const auto* dense_defaults =
          GetInput(context, node, dense_defaults_index);
      const auto key = GetString(dense_key_tensor, string_index);
      std::string k(key.str, key.len);
      const int elements_per_stride =
          dense_shapes[i].dims() ? dense_shapes[i].num_elements() : 1;
      switch (dense_output->type) {
        case kTfLiteInt64:
          data->config.dense.emplace_back(
              k, tf::DataTypeToEnum<tf::int64>::value, dense_shapes[i],
              AsTensor<tf::int64>(std::vector<tf::int64>(
                  dense_defaults->data.i64,
                  dense_defaults->data.i64 + elements_per_stride)),
              false, elements_per_stride);
          break;
        case kTfLiteFloat32:
          data->config.dense.emplace_back(
              k, tf::DataTypeToEnum<float>::value, dense_shapes[i],
              AsTensor<float>(std::vector<float>(
                  dense_defaults->data.f,
                  dense_defaults->data.f + elements_per_stride)),
              false, elements_per_stride);
          break;
        case kTfLiteString: {
          const int num_strings = GetStringCount(dense_defaults);
          std::vector<tstring> values;
          for (int i = 0; i < num_strings; ++i) {
            auto ref = GetString(dense_defaults, i);
            values.emplace_back(ref.str, ref.len);
          }
          data->config.dense.emplace_back(
              k, tf::DataTypeToEnum<tstring>::value, dense_shapes[i],
              AsTensor<tstring>(values), false, elements_per_stride);
          break;
        }
        default:
          return kTfLiteError;
      }
    }

    int offset = 3 * data->sparse_size;
    for (int i = 0; i < data->dense_size; i++) {
      auto* parse_output = GetOutput(context, node, i + offset);
      data->got.dense_values.push_back(parse_output);
      if (parse_output->type == kTfLiteString) {
        tf::TensorShape shape;
        if (parse_output->dims->size == 1) {
          shape.AddDim(parse_output->dims->data[0]);
        } else {
          shape.AddDim(GetTensorShape(parse_output).FlatSize());
        }
        data->got.dense_tensors[i] =
            tf::Tensor(tf::DataTypeToEnum<tstring>::value, shape);
      }
    }

    size_t config_size = data->config.dense.size();
    config_size += data->config.sparse.size();
    data->config_index_size = config_size;
    auto config_index = std::make_unique<ConfigIndex>(config_size);
    bool ok = true;
    int max_length = 0;
    for (size_t d = 0; d < data->config.dense.size(); ++d) {
      auto s = data->config.dense[d].feature_name;
      max_length = s.length() > max_length ? s.length() : max_length;
    }
    for (size_t d = 0; d < data->config.sparse.size(); ++d) {
      auto s = data->config.sparse[d].feature_name;
      max_length = s.length() > max_length ? s.length() : max_length;
    }
    if (data->quick_filter) {
      free(data->quick_filter);
    }
    data->quick_filter =
        static_cast<bool*>(malloc(++max_length * sizeof(bool)));
    memset(data->quick_filter, 0, max_length * sizeof(bool));
    data->quick_filter_size = max_length;
    for (size_t d = 0; d < data->config.dense.size(); ++d) {
      const auto& s = data->config.dense[d].feature_name;
      data->quick_filter[s.length()] = true;
    }
    for (size_t d = 0; d < data->config.sparse.size(); ++d) {
      const auto& s = data->config.sparse[d].feature_name;
      data->quick_filter[s.length()] = true;
    }

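    // Build the feature-name index. If any insertion collides under the
    // current hash seed, retry with a fresh seed, up to 1000 attempts.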
    for (int i = 0; i < 1000; ++i) {
      for (size_t d = 0; d < data->config.dense.size(); ++d) {
        ok &= config_index->InsertUnique(
            data->hasher(data->config.dense[d].feature_name),
            {d, Type::Dense});
      }
      for (size_t d = 0; d < data->config.sparse.size(); ++d) {
        ok &= config_index->InsertUnique(
            data->hasher(data->config.sparse[d].feature_name),
            {d, Type::Sparse});
      }
      if (ok) {
        break;
      }
      data->hasher.seed++;
      config_index->Clear(config_size);
      ok = true;
    }
    if (!ok) {
      return kTfLiteError;
    }
    data->config_index = std::move(config_index);
    data->created = true;
  }

  const TfLiteTensor* serialized = GetInput(context, node, kExampleTensor);

  std::map<absl::string_view, int> stats;
  const auto status = FastParseExampleLite(
      data->config, serialized, {}, data->quick_filter,
      data->quick_filter_size, data->config_index, data->config_index_size,
      &data->hasher, &data->got, stats, context);
  if (status != tf::Status::OK()) {
    TF_LITE_KERNEL_LOG(context, status.ToString().c_str());
    return kTfLiteError;
  }
  return kTfLiteOk;
}


void Free(TfLiteContext* context, void* buffer) {
  auto* obj = reinterpret_cast<OpData*>(buffer);
  delete obj;
}

}  // namespace parse_example

TfLiteRegistration* Register_PARSE_EXAMPLE() {
  static TfLiteRegistration r = {
      parse_example::Init, parse_example::Free,
      parse_example::PrepareParseExample<parse_example::V1>,
      parse_example::EvalParseExample<parse_example::V1>};
  return &r;
}

TfLiteRegistration* Register_PARSE_EXAMPLE_V2() {
  static TfLiteRegistration r = {
      parse_example::Init, parse_example::Free,
      parse_example::PrepareParseExample<parse_example::V2>,
      parse_example::EvalParseExample<parse_example::V2>};
  return &r;
}

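// Registers both kernels with a MutableOpResolver. A minimal usage sketch
// (assumes the model contains ParseExample / ParseExampleV2 nodes):
//
//   tflite::MutableOpResolver resolver;
//   tflite::ops::custom::AddParseExampleOp(&resolver);
//   // Hand `resolver` to tflite::InterpreterBuilder along with the model.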
extern "C" void AddParseExampleOp(::tflite::MutableOpResolver* resolver) {
  resolver->AddCustom("ParseExample", Register_PARSE_EXAMPLE());
  resolver->AddCustom("ParseExampleV2", Register_PARSE_EXAMPLE_V2());
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite