/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

24 #include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace tflite {
namespace task {
namespace text {
namespace nlclassifier {

using ::absl::StatusCode;
using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::tflite::TensorMetadata;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::RegexTokenizer;
using ::tflite::support::text::tokenizer::Tokenizer;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::support::utils::LoadVocabFromBuffer;
using ::tflite::task::core::Category;
using ::tflite::task::core::Dequantize;
using ::tflite::task::core::GetStringAtIndex;
using ::tflite::task::core::PopulateTensor;

namespace {
constexpr int kRegexTokenizerInputTensorIndex = 0;
constexpr int kRegexTokenizerProcessUnitIndex = 0;

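// Loads the first entry of `associated_files` (the tokenizer vocab file)
// through the metadata extractor, returning an error if the entry is missing
// or unnamed.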
StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
        associated_files,
    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
  if (associated_files == nullptr || associated_files->size() < 1 ||
      associated_files->Get(0)->name() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid vocab_file from input process unit.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
                   metadata_extractor->GetAssociatedFile(
                       associated_files->Get(0)->name()->str()));
  return vocab_buffer;
}

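// Builds a RegexTokenizer from the RegexTokenizerOptions process unit in the
// model metadata, validating the delimiter pattern and the presence of the
// <UNKNOWN> and <PAD> tokens in the vocab.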
StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit(
    const tflite::ProcessUnit* tokenizer_process_unit,
    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
  if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No metadata or input process unit found.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  if (tokenizer_process_unit->options_type() !=
      ProcessUnitOptions_RegexTokenizerOptions) {
    return CreateStatusWithPayload(
        absl::StatusCode::kNotFound,
        absl::StrCat(
            "Incorrect options_type:", tokenizer_process_unit->options_type(),
            " need RegexTokenizerOptions."),
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  const tflite::RegexTokenizerOptions* options =
      tokenizer_process_unit->options_as<RegexTokenizerOptions>();
  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
                   CheckAndLoadFirstAssociatedFile(options->vocab_file(),
                                                   metadata_extractor));
  if (options->delim_regex_pattern() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid delim_regex_pattern from input process unit.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  std::unique_ptr<RegexTokenizer> regex_tokenizer =
      absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
                                        vocab_buffer.data(),
                                        vocab_buffer.size());

  int unknown_token_id = 0;
  if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "RegexTokenizer doesn't have <UNKNOWN> token.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  int pad_token_id = 0;
  if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "RegexTokenizer doesn't have <PAD> token.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  return regex_tokenizer;
}

}  // namespace

const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; }

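// Loads the label list for the output tensor from the TENSOR_AXIS_LABELS
// associated file in `metadata`, populating labels_vector_ on success.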
absl::Status NLClassifier::TrySetLabelFromMetadata(
    const TensorMetadata* metadata) {
  if (metadata == nullptr) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "Metadata not found for output tensor",
                                   TfLiteSupportStatus::kMetadataNotFoundError);
  }
  const auto* associated_files = metadata->associated_files();
  if (associated_files == nullptr || associated_files->size() == 0) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No label file found for tensor metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
  const tflite::AssociatedFile* associated_file =
      associated_files->Get(kOutputTensorLabelFileIndex);
  if (associated_file->type() != AssociatedFileType_TENSOR_AXIS_LABELS) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Incorrect label type found for tensor metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
  // Read the label file that was validated above.
  tflite::support::StatusOr<absl::string_view> label_buffer =
      GetMetadataExtractor()->GetAssociatedFile(
          associated_file->name()->str());
  if (label_buffer.ok()) {
    labels_vector_ =
        absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer(
            label_buffer.value().data(), label_buffer.value().size()));
    return absl::OkStatus();
  } else {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Failed to extract label file from metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
}

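// Example (illustrative): for a binary sentiment model, Classify("great film")
// could return {{"negative", 0.01}, {"positive", 0.99}}.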
std::vector<Category> NLClassifier::Classify(const std::string& text) {
  // The NLClassifier implementations of Preprocess() and Postprocess() never
  // return errors, so it is safe to call value() directly.
  return Infer(text).value();
}

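// Writes `input` into the input tensor: either as token ids produced by the
// regex tokenizer (when the model carries RegexTokenizer metadata) or as a
// raw string.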
absl::Status NLClassifier::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
      input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
      options_.input_tensor_name, options_.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input tensor found from NLClassifierOptions.",
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }

  if (HasRegexTokenizerMetadata()) {
    //                              |<-------sentence_length-------->|
    // input_tensor                 <START>, t1, t2... <PAD>, <PAD>...
    // <START> is optional; each token t1, t2, ... is replaced by <UNKNOWN>
    // if it is not found in the tokenizer vocab.
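    // Example: with max_sentence_length = 6, a <START> token, and input
    // "hello world", the tensor is filled as
    // [<START>, id(hello), id(world), <PAD>, <PAD>, <PAD>].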
    TokenizerResult result = tokenizer_->Tokenize(input);

    size_t max_sentence_length = input_tensor->dims->size == 2
                                     ? input_tensor->dims->data[1]
                                     : input_tensor->dims->data[0];

    int unknown_token_id = 0;
    tokenizer_->GetUnknownToken(&unknown_token_id);

    int pad_token_id = 0;
    tokenizer_->GetPadToken(&pad_token_id);

    std::vector<int> input_tokens(max_sentence_length, pad_token_id);
    int start_token_id = 0;
    size_t input_token_index = 0;
    if (tokenizer_->GetStartToken(&start_token_id)) {
      input_tokens[0] = start_token_id;
      input_token_index = 1;
    }

    for (size_t i = 0; (i < result.subwords.size()) &&
                       (input_token_index < max_sentence_length);
         ++i, ++input_token_index) {
      const std::string& token = result.subwords[i];
      int token_id = 0;
      if (tokenizer_->LookupId(token, &token_id)) {
        input_tokens[input_token_index] = token_id;
      } else {
        input_tokens[input_token_index] = unknown_token_id;
      }
    }

    PopulateTensor(input_tokens, input_tensor);
  } else {
    PopulateTensor(input, input_tensor);
  }
  return absl::OkStatus();
}

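// Maps the output score (and optional label) tensors to a list of categories.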
StatusOr<std::vector<Category>> NLClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const std::string& /*input*/) {
  return BuildResults(
      FindTensorWithNameOrIndex(
          output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
          options_.output_score_tensor_name,
          options_.output_score_tensor_index),
      FindTensorWithNameOrIndex(
          output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
          options_.output_label_tensor_name,
          options_.output_label_tensor_index));
}

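// Builds Category results from the score tensor; labels come from (in order
// of precedence) the metadata label file, the label tensor, or the category
// index itself.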
std::vector<Category> NLClassifier::BuildResults(const TfLiteTensor* scores,
                                                 const TfLiteTensor* labels) {
  bool use_index_as_labels = (labels_vector_ == nullptr) && (labels == nullptr);
  // Some models output scores with transposed shape [1, categories]
  int categories =
      scores->dims->size == 2 ? scores->dims->data[1] : scores->dims->data[0];

  std::vector<Category> predictions;
  predictions.reserve(categories);

  bool should_dequantize = scores->type == kTfLiteUInt8 ||
                           scores->type == kTfLiteInt8 ||
                           scores->type == kTfLiteInt16;
  for (int index = 0; index < categories; index++) {
    std::string label;
    if (use_index_as_labels) {
      label = std::to_string(index);
    } else if (labels_vector_ == nullptr) {
      if (labels->type == kTfLiteString) {
        label = GetStringAtIndex(labels, index);
      } else if (labels->type == kTfLiteInt32) {
        label = std::to_string(GetTensorData<int>(labels)[index]);
      }
    } else {
      label = (*labels_vector_)[index];
    }
    if (should_dequantize) {
      predictions.push_back(Category(label, Dequantize(*scores, index)));
    } else if (scores->type == kTfLiteBool) {
      predictions.push_back(
          Category(label, GetTensorData<bool>(scores)[index] ? 1.0 : 0.0));
    } else {
      predictions.push_back(
          Category(label, scores->type == kTfLiteFloat32
                              ? GetTensorData<float>(scores)[index]
                              : GetTensorData<double>(scores)[index]));
    }
  }

  return predictions;
}
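
// Validates the input and output tensors against `options` and the model
// metadata, and sets up the regex tokenizer and the label mapping.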
absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
  options_ = options;
  // The input tensor should be of type STRING, or INT32 when a regex
  // tokenizer is configured in the metadata.
  auto input_tensor = FindTensorWithNameOrIndex(
      GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
      options.input_tensor_name, options.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No input tensor found with name ",
                     options.input_tensor_name, " or at index ",
                     options.input_tensor_index),
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }
  if (HasRegexTokenizerMetadata()) {
    if (input_tensor->type != kTfLiteInt32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
                       ". Requested INT32, got ",
                       TfLiteTypeGetName(input_tensor->type), "."),
          TfLiteSupportStatus::kInvalidInputTensorTypeError);
    }
    RETURN_IF_ERROR(SetupRegexTokenizer());
  } else {
    if (input_tensor->type != kTfLiteString) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
                       ". Requested STRING, got ",
                       TfLiteTypeGetName(input_tensor->type), "."),
          TfLiteSupportStatus::kInvalidInputTensorTypeError);
    }
  }

  // The output score tensor should be of type UINT8/INT8/INT16 (quantized),
  // FLOAT32/FLOAT64 (dequantized), or BOOL.
  std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors();
  const Vector<Offset<TensorMetadata>>* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();

  const auto scores = FindTensorWithNameOrIndex(
      output_tensors, output_tensor_metadatas, options.output_score_tensor_name,
      options.output_score_tensor_index);
  if (scores == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No output score tensor found with name ",
                     options.output_score_tensor_name, " or at index ",
                     options.output_score_tensor_index),
        TfLiteSupportStatus::kOutputTensorNotFoundError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8,   kTfLiteInt8,
                                               kTfLiteInt16,   kTfLiteFloat32,
                                               kTfLiteFloat64, kTfLiteBool};
  if (!absl::c_linear_search(valid_types, scores->type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("Type mismatch for score tensor ", scores->name,
                     ". Requested one of these types: "
                     "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ",
                     TfLiteTypeGetName(scores->type), "."),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }

  // Extract the associated label file from the output score tensor if one
  // exists; well-formed metadata has the same number of tensors as the model.
  if (output_tensor_metadatas &&
      output_tensor_metadatas->size() == output_tensors.size()) {
    for (int i = 0; i < output_tensor_metadatas->size(); ++i) {
      const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i);
      if ((metadata->name() && metadata->name()->string_view() ==
                                   options.output_score_tensor_name) ||
          i == options.output_score_tensor_index) {
        if (TrySetLabelFromMetadata(metadata).ok()) {
          return absl::OkStatus();
        }
      }
    }
  }

  // If labels_vector_ was not set up from metadata, try to register the
  // output label tensor from options.
  if (labels_vector_ == nullptr) {
    // The output label tensor, if one exists, should be of type STRING or
    // INT32.
    auto labels = FindTensorWithNameOrIndex(
        output_tensors, output_tensor_metadatas,
        options.output_label_tensor_name, options.output_label_tensor_index);
    if (labels != nullptr && labels->type != kTfLiteString &&
        labels->type != kTfLiteInt32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for label tensor ", labels->name,
                       ". Requested STRING or INT32, got ",
                       TfLiteTypeGetName(labels->type), "."),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }
  return absl::OkStatus();
}

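// Factory helpers: build an NLClassifier from a flatbuffer in memory, a file
// path, or a file descriptor, then run Initialize() with `options`.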
StatusOr<std::unique_ptr<NLClassifier>>
NLClassifier::CreateFromBufferAndOptions(
    const char* model_buffer_data, size_t model_buffer_size,
    const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(
      nl_classifier,
      core::TaskAPIFactory::CreateFromBuffer<NLClassifier>(
          model_buffer_data, model_buffer_size, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
    const std::string& path_to_model, const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(nl_classifier,
                   core::TaskAPIFactory::CreateFromFile<NLClassifier>(
                       path_to_model, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

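// Example usage (a sketch; the model path and resolver choice are
// illustrative):
//
//   ASSIGN_OR_RETURN(
//       std::unique_ptr<NLClassifier> classifier,
//       NLClassifier::CreateFromFileAndOptions(
//           "/path/to/model.tflite", NLClassifierOptions{},
//           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()));
//   std::vector<Category> categories = classifier->Classify("input text");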
StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
    int fd, const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(nl_classifier,
                   core::TaskAPIFactory::CreateFromFileDescriptor<NLClassifier>(
                       fd, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

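// Returns true if the input tensor metadata declares a RegexTokenizer
// process unit.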
bool NLClassifier::HasRegexTokenizerMetadata() {
  const TensorMetadata* input_tensor_metadata =
      GetMetadataExtractor()->GetInputTensorMetadata(
          kRegexTokenizerInputTensorIndex);
  if (input_tensor_metadata == nullptr) {
    return false;
  }
  tflite::support::StatusOr<const tflite::ProcessUnit*> status =
      GetMetadataExtractor()->FindFirstProcessUnit(
          *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
  return status.ok() ? status.value() != nullptr : false;
}

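// Creates tokenizer_ from the RegexTokenizer process unit in the input tensor
// metadata. Only called when HasRegexTokenizerMetadata() is true, and
// CreateRegexTokenizerFromProcessUnit only produces RegexTokenizer instances,
// so the downcast below succeeds.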
absl::Status NLClassifier::SetupRegexTokenizer() {
  ASSIGN_OR_RETURN(
      std::unique_ptr<Tokenizer> base_tokenizer,
      CreateRegexTokenizerFromProcessUnit(
          GetMetadataExtractor()
              ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
              ->process_units()
              ->Get(kRegexTokenizerProcessUnitIndex),
          GetMetadataExtractor()));

  tokenizer_ = std::unique_ptr<RegexTokenizer>(
      dynamic_cast<RegexTokenizer*>(base_tokenizer.release()));

  return absl::OkStatus();
}

}  // namespace nlclassifier
}  // namespace text
}  // namespace task
}  // namespace tflite