/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace tflite {
namespace task {
namespace text {
namespace nlclassifier {

using ::absl::StatusCode;
using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::tflite::TensorMetadata;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::RegexTokenizer;
using ::tflite::support::text::tokenizer::Tokenizer;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::support::utils::LoadVocabFromBuffer;
using ::tflite::task::core::Category;
using ::tflite::task::core::Dequantize;
using ::tflite::task::core::GetStringAtIndex;
using ::tflite::task::core::PopulateTensor;

namespace {
constexpr int kRegexTokenizerInputTensorIndex = 0;
constexpr int kRegexTokenizerProcessUnitIndex = 0;

// Returns the contents of the first associated file attached to
// `associated_files`, or an error if the file is missing or unnamed.
StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
        associated_files,
    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
  if (associated_files == nullptr || associated_files->size() < 1 ||
      associated_files->Get(0)->name() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid vocab_file from input process unit.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
                   metadata_extractor->GetAssociatedFile(
                       associated_files->Get(0)->name()->str()));
  return vocab_buffer;
}

// Builds a RegexTokenizer from the tokenizer process unit in the model
// metadata; fails if the process unit, its vocab file, or its regex pattern
// is invalid.
StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit(
    const tflite::ProcessUnit* tokenizer_process_unit,
    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
  if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No metadata or input process unit found.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  if (tokenizer_process_unit->options_type() !=
      ProcessUnitOptions_RegexTokenizerOptions) {
    return CreateStatusWithPayload(
        absl::StatusCode::kNotFound,
        absl::StrCat(
            "Incorrect options_type:", tokenizer_process_unit->options_type(),
            " need RegexTokenizerOptions."),
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  const tflite::RegexTokenizerOptions* options =
      tokenizer_process_unit->options_as<RegexTokenizerOptions>();
  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
                   CheckAndLoadFirstAssociatedFile(options->vocab_file(),
                                                   metadata_extractor));
  if (options->delim_regex_pattern() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid delim_regex_pattern from input process unit.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  std::unique_ptr<RegexTokenizer> regex_tokenizer =
      absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
                                        vocab_buffer.data(),
                                        vocab_buffer.size());

  int unknown_token_id = 0;
  if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "RegexTokenizer doesn't have <UNKNOWN> token.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  int pad_token_id = 0;
  if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "RegexTokenizer doesn't have <PAD> token.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  return std::move(regex_tokenizer);
}

}  // namespace

const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; }

absl::Status NLClassifier::TrySetLabelFromMetadata(
    const TensorMetadata* metadata) {
  if (metadata == nullptr) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "Metadata not found for output tensor",
                                   TfLiteSupportStatus::kMetadataNotFoundError);
  }
  const auto* associated_files = metadata->associated_files();
  if (associated_files == nullptr || associated_files->size() == 0) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No label file found for tensor metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
  const tflite::AssociatedFile* associated_file =
      associated_files->Get(kOutputTensorLabelFileIndex);
  if (associated_file->type() != AssociatedFileType_TENSOR_AXIS_LABELS) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Incorrect label type found for tensor metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
  tflite::support::StatusOr<absl::string_view> label_buffer =
      GetMetadataExtractor()->GetAssociatedFile(associated_file->name()->str());
  if (label_buffer.ok()) {
    labels_vector_ =
        absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer(
            label_buffer.value().data(), label_buffer.value().size()));
    return absl::OkStatus();
  } else {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Failed to extract label file from metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
}

std::vector<Category> NLClassifier::Classify(const std::string& text) {
  // The NLClassifier implementation for Preprocess() and Postprocess() never
  // returns errors: just call value().
  return Infer(text).value();
}
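
// Example usage of the classification API (a minimal sketch; the model path
// and input sentence below are hypothetical):
//
//   std::unique_ptr<NLClassifier> classifier =
//       NLClassifier::CreateFromFileAndOptions("/path/to/model.tflite",
//                                              NLClassifierOptions{})
//           .value();
//   std::vector<Category> categories =
//       classifier->Classify("What a waste of my time.");
//   // Each Category carries a label (class name or index) and a score.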

absl::Status NLClassifier::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
      input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
      options_.input_tensor_name, options_.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input tensor found from NLClassifierOptions.",
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }

  if (HasRegexTokenizerMetadata()) {
    //                              |<-------sentence_length-------->|
    // input_tensor                 <START>, t1, t2... <PAD>, <PAD>...
    // <START> is optional, and tokens t1, t2, ... that are not found in the
    // tokenizer vocab are replaced by <UNKNOWN>.
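    // For example, with max_sentence_length = 6, a vocab containing both
    // words, and a <START> token, "hello world" would be encoded as:
    //   [<START>, id("hello"), id("world"), <PAD>, <PAD>, <PAD>]
    // (the ids shown are hypothetical vocab lookups).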
    TokenizerResult result = tokenizer_->Tokenize(input);

    // The input tensor is shaped either [1, max_sentence_length] or
    // [max_sentence_length].
    size_t max_sentence_length = input_tensor->dims->size == 2
                                     ? input_tensor->dims->data[1]
                                     : input_tensor->dims->data[0];

    int unknown_token_id = 0;
    tokenizer_->GetUnknownToken(&unknown_token_id);

    int pad_token_id = 0;
    tokenizer_->GetPadToken(&pad_token_id);

    std::vector<int> input_tokens(max_sentence_length, pad_token_id);
    int start_token_id = 0;
    size_t input_token_index = 0;
    if (tokenizer_->GetStartToken(&start_token_id)) {
      input_tokens[0] = start_token_id;
      input_token_index = 1;
    }

    for (size_t i = 0; (i < result.subwords.size()) &&
                       (input_token_index < max_sentence_length);
         ++i, ++input_token_index) {
      const std::string& token = result.subwords[i];
      int token_id = 0;
      if (tokenizer_->LookupId(token, &token_id)) {
        input_tokens[input_token_index] = token_id;
      } else {
        input_tokens[input_token_index] = unknown_token_id;
      }
    }

    PopulateTensor(input_tokens, input_tensor);
  } else {
    PopulateTensor(input, input_tensor);
  }
  return absl::OkStatus();
}

StatusOr<std::vector<Category>> NLClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const std::string& /*input*/) {
  return BuildResults(
      FindTensorWithNameOrIndex(
          output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
          options_.output_score_tensor_name,
          options_.output_score_tensor_index),
      FindTensorWithNameOrIndex(
          output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
          options_.output_label_tensor_name,
          options_.output_label_tensor_index));
}

std::vector<Category> NLClassifier::BuildResults(const TfLiteTensor* scores,
                                                 const TfLiteTensor* labels) {
  bool use_index_as_labels = (labels_vector_ == nullptr) && (labels == nullptr);
  // Some models output scores with transposed shape [1, categories].
  int categories =
      scores->dims->size == 2 ? scores->dims->data[1] : scores->dims->data[0];

  std::vector<Category> predictions;
  predictions.reserve(categories);

  bool should_dequantize = scores->type == kTfLiteUInt8 ||
                           scores->type == kTfLiteInt8 ||
                           scores->type == kTfLiteInt16;
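  // Dequantize() applies the standard TFLite affine dequantization,
  // real_value = scale * (quantized_value - zero_point), using the score
  // tensor's quantization parameters.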
  for (int index = 0; index < categories; index++) {
    std::string label;
    if (use_index_as_labels) {
      label = std::to_string(index);
    } else if (labels_vector_ == nullptr) {
      if (labels->type == kTfLiteString) {
        label = GetStringAtIndex(labels, index);
      } else if (labels->type == kTfLiteInt32) {
        label = std::to_string(GetTensorData<int>(labels)[index]);
      }
    } else {
      label = (*labels_vector_)[index];
    }
    if (should_dequantize) {
      predictions.push_back(Category(label, Dequantize(*scores, index)));
    } else if (scores->type == kTfLiteBool) {
      predictions.push_back(
          Category(label, GetTensorData<bool>(scores)[index] ? 1.0 : 0.0));
    } else {
      predictions.push_back(
          Category(label, scores->type == kTfLiteFloat32
                              ? GetTensorData<float>(scores)[index]
                              : GetTensorData<double>(scores)[index]));
    }
  }

  return predictions;
}

absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
  options_ = options;
  // The input tensor should be of type STRING (or INT32 when a regex
  // tokenizer is configured in the metadata; see below).
  auto input_tensor = FindTensorWithNameOrIndex(
      GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
      options.input_tensor_name, options.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No input tensor found with name ",
                     options.input_tensor_name, " or at index ",
                     options.input_tensor_index),
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }
  if (HasRegexTokenizerMetadata()) {
    if (input_tensor->type != kTfLiteInt32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
                       ". Requested INT32, got ",
                       TfLiteTypeGetName(input_tensor->type), "."),
          TfLiteSupportStatus::kInvalidInputTensorTypeError);
    }
    RETURN_IF_ERROR(SetupRegexTokenizer());
  } else {
    if (input_tensor->type != kTfLiteString) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
                       ". Requested STRING, got ",
                       TfLiteTypeGetName(input_tensor->type), "."),
          TfLiteSupportStatus::kInvalidInputTensorTypeError);
    }
  }

  // The output score tensor should be of type UINT8/INT8/INT16 (quantized),
  // FLOAT32/FLOAT64 (dequantized), or BOOL.
  std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors();
  const Vector<Offset<TensorMetadata>>* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();

  const auto scores = FindTensorWithNameOrIndex(
      output_tensors, output_tensor_metadatas, options.output_score_tensor_name,
      options.output_score_tensor_index);
  if (scores == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No output score tensor found with name ",
                     options.output_score_tensor_name, " or at index ",
                     options.output_score_tensor_index),
        TfLiteSupportStatus::kOutputTensorNotFoundError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8,   kTfLiteInt8,
                                               kTfLiteInt16,   kTfLiteFloat32,
                                               kTfLiteFloat64, kTfLiteBool};
  if (!absl::c_linear_search(valid_types, scores->type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("Type mismatch for score tensor ", scores->name,
                     ". Requested one of these types: "
                     "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ",
                     TfLiteTypeGetName(scores->type), "."),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }

  // Extract the associated label file from the output score tensor metadata,
  // if one exists. Well-formed metadata should have the same number of tensor
  // metadata entries as the model has output tensors.
  if (output_tensor_metadatas &&
      output_tensor_metadatas->size() == output_tensors.size()) {
    for (int i = 0; i < output_tensor_metadatas->size(); ++i) {
      const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i);
      if ((metadata->name() && metadata->name()->string_view() ==
                                   options.output_score_tensor_name) ||
          i == options.output_score_tensor_index) {
        if (TrySetLabelFromMetadata(metadata).ok()) {
          return absl::OkStatus();
        }
      }
    }
  }

  // If labels_vector_ was not set up from the metadata, try to register an
  // output label tensor from the options.
  if (labels_vector_ == nullptr) {
    // The output label tensor, if it exists, should be of type STRING or
    // INT32.
    auto labels = FindTensorWithNameOrIndex(
        output_tensors, output_tensor_metadatas,
        options.output_label_tensor_name, options.output_label_tensor_index);
    if (labels != nullptr && labels->type != kTfLiteString &&
        labels->type != kTfLiteInt32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for label tensor ", labels->name,
                       ". Requested STRING or INT32, got ",
                       TfLiteTypeGetName(labels->type), "."),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }
  return absl::OkStatus();
}

StatusOr<std::unique_ptr<NLClassifier>>
NLClassifier::CreateFromBufferAndOptions(
    const char* model_buffer_data, size_t model_buffer_size,
    const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(
      nl_classifier,
      core::TaskAPIFactory::CreateFromBuffer<NLClassifier>(
          model_buffer_data, model_buffer_size, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
    const std::string& path_to_model, const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(nl_classifier,
                   core::TaskAPIFactory::CreateFromFile<NLClassifier>(
                       path_to_model, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
    int fd, const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(nl_classifier,
                   core::TaskAPIFactory::CreateFromFileDescriptor<NLClassifier>(
                       fd, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}
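
// A custom op resolver can also be supplied when creating the classifier
// (a sketch; most models only need the builtin resolver):
//
//   auto classifier_or = NLClassifier::CreateFromFileAndOptions(
//       "/path/to/model.tflite", NLClassifierOptions{},
//       absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());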

bool NLClassifier::HasRegexTokenizerMetadata() {
  const TensorMetadata* input_tensor_metadata =
      GetMetadataExtractor()->GetInputTensorMetadata(
          kRegexTokenizerInputTensorIndex);
  if (input_tensor_metadata == nullptr) {
    return false;
  }
  tflite::support::StatusOr<const tflite::ProcessUnit*> status =
      GetMetadataExtractor()->FindFirstProcessUnit(
          *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
  return status.ok() ? status.value() != nullptr : false;
}

absl::Status NLClassifier::SetupRegexTokenizer() {
  ASSIGN_OR_RETURN(
      std::unique_ptr<Tokenizer> base_tokenizer,
      CreateRegexTokenizerFromProcessUnit(
          GetMetadataExtractor()
              ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
              ->process_units()
              ->Get(kRegexTokenizerProcessUnitIndex),
          GetMetadataExtractor()));

  // CreateRegexTokenizerFromProcessUnit() always produces a RegexTokenizer,
  // so the dynamic_cast below cannot fail.
  tokenizer_ = std::unique_ptr<RegexTokenizer>(
      dynamic_cast<RegexTokenizer*>(base_tokenizer.release()));

  return absl::OkStatus();
}

}  // namespace nlclassifier
}  // namespace text
}  // namespace task
}  // namespace tflite