• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"

#include <stddef.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

41 namespace tflite {
42 namespace task {
43 namespace text {
44 namespace nlclassifier {
45 
46 using ::tflite::support::CreateStatusWithPayload;
47 using ::tflite::support::StatusOr;
48 using ::tflite::support::TfLiteSupportStatus;
49 using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
50 using ::tflite::support::text::tokenizer::TokenizerResult;
51 using ::tflite::task::core::FindTensorByName;
52 using ::tflite::task::core::PopulateTensor;
53 
namespace {
// Metadata names of the three BERT input tensors.
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
// Metadata name of the classification output tensor.
constexpr char kScoreTensorName[] = "probability";
// Special BERT tokens wrapped around the tokenized query.
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
// Index of the tokenizer process unit within the input metadata.
constexpr int kTokenizerProcessUnitIndex = 0;
}  // namespace
63 
Preprocess(const std::vector<TfLiteTensor * > & input_tensors,const std::string & input)64 absl::Status BertNLClassifier::Preprocess(
65     const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
66   auto* input_tensor_metadatas =
67       GetMetadataExtractor()->GetInputTensorMetadata();
68   auto* ids_tensor =
69       FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName);
70   auto* mask_tensor =
71       FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName);
72   auto* segment_ids_tensor = FindTensorByName(
73       input_tensors, input_tensor_metadatas, kSegmentIdsTensorName);
74 
75   std::string processed_input = input;
76   absl::AsciiStrToLower(&processed_input);
77 
78   TokenizerResult input_tokenize_results;
79   input_tokenize_results = tokenizer_->Tokenize(processed_input);
80 
81   // 2 accounts for [CLS], [SEP]
82   absl::Span<const std::string> query_tokens =
83       absl::MakeSpan(input_tokenize_results.subwords.data(),
84                      input_tokenize_results.subwords.data() +
85                          std::min(static_cast<size_t>(kMaxSeqLen - 2),
86                                   input_tokenize_results.subwords.size()));
87 
88   std::vector<std::string> tokens;
89   tokens.reserve(2 + query_tokens.size());
90   // Start of generating the features.
91   tokens.push_back(kClassificationToken);
92   // For query input.
93   for (const auto& query_token : query_tokens) {
94     tokens.push_back(query_token);
95   }
96   // For Separation.
97   tokens.push_back(kSeparator);
98 
99   std::vector<int> input_ids(kMaxSeqLen, 0);
100   std::vector<int> input_mask(kMaxSeqLen, 0);
101   // Convert tokens back into ids and set mask
102   for (int i = 0; i < tokens.size(); ++i) {
103     tokenizer_->LookupId(tokens[i], &input_ids[i]);
104     input_mask[i] = 1;
105   }
106   //                             |<-----------kMaxSeqLen---------->|
107   // input_ids                 [CLS] s1  s2...  sn [SEP]  0  0...  0
108   // input_masks                 1    1   1...  1    1    0  0...  0
109   // segment_ids                 0    0   0...  0    0    0  0...  0
110 
111   PopulateTensor(input_ids, ids_tensor);
112   PopulateTensor(input_mask, mask_tensor);
113   PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor);
114 
115   return absl::OkStatus();
116 }
117 
Postprocess(const std::vector<const TfLiteTensor * > & output_tensors,const std::string &)118 StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
119     const std::vector<const TfLiteTensor*>& output_tensors,
120     const std::string& /*input*/) {
121   if (output_tensors.size() != 1) {
122     return CreateStatusWithPayload(
123         absl::StatusCode::kInvalidArgument,
124         absl::StrFormat("BertNLClassifier models are expected to have only 1 "
125                         "output, found %d",
126                         output_tensors.size()),
127         TfLiteSupportStatus::kInvalidNumOutputTensorsError);
128   }
129   const TfLiteTensor* scores = FindTensorByName(
130       output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
131       kScoreTensorName);
132 
133   // optional labels extracted from metadata
134   return BuildResults(scores, /*labels=*/nullptr);
135 }
136 
137 StatusOr<std::unique_ptr<BertNLClassifier>>
CreateFromFile(const std::string & path_to_model_with_metadata,std::unique_ptr<tflite::OpResolver> resolver)138 BertNLClassifier::CreateFromFile(
139     const std::string& path_to_model_with_metadata,
140     std::unique_ptr<tflite::OpResolver> resolver) {
141   std::unique_ptr<BertNLClassifier> bert_nl_classifier;
142   ASSIGN_OR_RETURN(bert_nl_classifier,
143                    core::TaskAPIFactory::CreateFromFile<BertNLClassifier>(
144                        path_to_model_with_metadata, std::move(resolver)));
145   RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
146   return std::move(bert_nl_classifier);
147 }
148 
149 StatusOr<std::unique_ptr<BertNLClassifier>>
CreateFromBuffer(const char * model_with_metadata_buffer_data,size_t model_with_metadata_buffer_size,std::unique_ptr<tflite::OpResolver> resolver)150 BertNLClassifier::CreateFromBuffer(
151     const char* model_with_metadata_buffer_data,
152     size_t model_with_metadata_buffer_size,
153     std::unique_ptr<tflite::OpResolver> resolver) {
154   std::unique_ptr<BertNLClassifier> bert_nl_classifier;
155   ASSIGN_OR_RETURN(bert_nl_classifier,
156                    core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>(
157                        model_with_metadata_buffer_data,
158                        model_with_metadata_buffer_size, std::move(resolver)));
159   RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
160   return std::move(bert_nl_classifier);
161 }
162 
CreateFromFd(int fd,std::unique_ptr<tflite::OpResolver> resolver)163 StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFd(
164     int fd, std::unique_ptr<tflite::OpResolver> resolver) {
165   std::unique_ptr<BertNLClassifier> bert_nl_classifier;
166   ASSIGN_OR_RETURN(
167       bert_nl_classifier,
168       core::TaskAPIFactory::CreateFromFileDescriptor<BertNLClassifier>(
169           fd, std::move(resolver)));
170   RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
171   return std::move(bert_nl_classifier);
172 }
173 
InitializeFromMetadata()174 absl::Status BertNLClassifier::InitializeFromMetadata() {
175   // Set up mandatory tokenizer.
176   const ProcessUnit* tokenizer_process_unit =
177       GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
178   if (tokenizer_process_unit == nullptr) {
179     return CreateStatusWithPayload(
180         absl::StatusCode::kInvalidArgument,
181         "No input process unit found from metadata.",
182         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
183   }
184   ASSIGN_OR_RETURN(tokenizer_,
185                    CreateTokenizerFromProcessUnit(tokenizer_process_unit,
186                                                   GetMetadataExtractor()));
187 
188   // Set up optional label vector.
189   TrySetLabelFromMetadata(
190       GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex))
191       .IgnoreError();
192   return absl::OkStatus();
193 }
194 
195 }  // namespace nlclassifier
196 }  // namespace text
197 }  // namespace task
198 }  // namespace tflite
199