1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
17
#include <stddef.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
24
25 #include "absl/status/status.h"
26 #include "absl/strings/ascii.h"
27 #include "absl/strings/str_format.h"
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/core/api/op_resolver.h"
30 #include "tensorflow/lite/string_type.h"
31 #include "tensorflow_lite_support/cc/common.h"
32 #include "tensorflow_lite_support/cc/port/status_macros.h"
33 #include "tensorflow_lite_support/cc/task/core/category.h"
34 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
35 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
36 #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
37 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
38 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
39 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
40
41 namespace tflite {
42 namespace task {
43 namespace text {
44 namespace nlclassifier {
45
// Names imported from the support library and task core utilities so the
// method bodies below can refer to them unqualified.
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::task::core::FindTensorByName;
using ::tflite::task::core::PopulateTensor;

namespace {
// Names of the three BERT input tensors. Inputs are located by matching these
// names against the model's input tensor metadata (via FindTensorByName), not
// by positional index.
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
// Name of the output tensor, in the output tensor metadata, that holds the
// classification scores.
constexpr char kScoreTensorName[] = "probability";
// Special tokens placed before ([CLS]) and after ([SEP]) the query subwords.
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
// Index of the tokenizer entry among the metadata's input process units.
constexpr int kTokenizerProcessUnitIndex = 0;
}  // namespace
63
Preprocess(const std::vector<TfLiteTensor * > & input_tensors,const std::string & input)64 absl::Status BertNLClassifier::Preprocess(
65 const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
66 auto* input_tensor_metadatas =
67 GetMetadataExtractor()->GetInputTensorMetadata();
68 auto* ids_tensor =
69 FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName);
70 auto* mask_tensor =
71 FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName);
72 auto* segment_ids_tensor = FindTensorByName(
73 input_tensors, input_tensor_metadatas, kSegmentIdsTensorName);
74
75 std::string processed_input = input;
76 absl::AsciiStrToLower(&processed_input);
77
78 TokenizerResult input_tokenize_results;
79 input_tokenize_results = tokenizer_->Tokenize(processed_input);
80
81 // 2 accounts for [CLS], [SEP]
82 absl::Span<const std::string> query_tokens =
83 absl::MakeSpan(input_tokenize_results.subwords.data(),
84 input_tokenize_results.subwords.data() +
85 std::min(static_cast<size_t>(kMaxSeqLen - 2),
86 input_tokenize_results.subwords.size()));
87
88 std::vector<std::string> tokens;
89 tokens.reserve(2 + query_tokens.size());
90 // Start of generating the features.
91 tokens.push_back(kClassificationToken);
92 // For query input.
93 for (const auto& query_token : query_tokens) {
94 tokens.push_back(query_token);
95 }
96 // For Separation.
97 tokens.push_back(kSeparator);
98
99 std::vector<int> input_ids(kMaxSeqLen, 0);
100 std::vector<int> input_mask(kMaxSeqLen, 0);
101 // Convert tokens back into ids and set mask
102 for (int i = 0; i < tokens.size(); ++i) {
103 tokenizer_->LookupId(tokens[i], &input_ids[i]);
104 input_mask[i] = 1;
105 }
106 // |<-----------kMaxSeqLen---------->|
107 // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
108 // input_masks 1 1 1... 1 1 0 0... 0
109 // segment_ids 0 0 0... 0 0 0 0... 0
110
111 PopulateTensor(input_ids, ids_tensor);
112 PopulateTensor(input_mask, mask_tensor);
113 PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor);
114
115 return absl::OkStatus();
116 }
117
Postprocess(const std::vector<const TfLiteTensor * > & output_tensors,const std::string &)118 StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
119 const std::vector<const TfLiteTensor*>& output_tensors,
120 const std::string& /*input*/) {
121 if (output_tensors.size() != 1) {
122 return CreateStatusWithPayload(
123 absl::StatusCode::kInvalidArgument,
124 absl::StrFormat("BertNLClassifier models are expected to have only 1 "
125 "output, found %d",
126 output_tensors.size()),
127 TfLiteSupportStatus::kInvalidNumOutputTensorsError);
128 }
129 const TfLiteTensor* scores = FindTensorByName(
130 output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
131 kScoreTensorName);
132
133 // optional labels extracted from metadata
134 return BuildResults(scores, /*labels=*/nullptr);
135 }
136
137 StatusOr<std::unique_ptr<BertNLClassifier>>
CreateFromFile(const std::string & path_to_model_with_metadata,std::unique_ptr<tflite::OpResolver> resolver)138 BertNLClassifier::CreateFromFile(
139 const std::string& path_to_model_with_metadata,
140 std::unique_ptr<tflite::OpResolver> resolver) {
141 std::unique_ptr<BertNLClassifier> bert_nl_classifier;
142 ASSIGN_OR_RETURN(bert_nl_classifier,
143 core::TaskAPIFactory::CreateFromFile<BertNLClassifier>(
144 path_to_model_with_metadata, std::move(resolver)));
145 RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
146 return std::move(bert_nl_classifier);
147 }
148
149 StatusOr<std::unique_ptr<BertNLClassifier>>
CreateFromBuffer(const char * model_with_metadata_buffer_data,size_t model_with_metadata_buffer_size,std::unique_ptr<tflite::OpResolver> resolver)150 BertNLClassifier::CreateFromBuffer(
151 const char* model_with_metadata_buffer_data,
152 size_t model_with_metadata_buffer_size,
153 std::unique_ptr<tflite::OpResolver> resolver) {
154 std::unique_ptr<BertNLClassifier> bert_nl_classifier;
155 ASSIGN_OR_RETURN(bert_nl_classifier,
156 core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>(
157 model_with_metadata_buffer_data,
158 model_with_metadata_buffer_size, std::move(resolver)));
159 RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
160 return std::move(bert_nl_classifier);
161 }
162
CreateFromFd(int fd,std::unique_ptr<tflite::OpResolver> resolver)163 StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFd(
164 int fd, std::unique_ptr<tflite::OpResolver> resolver) {
165 std::unique_ptr<BertNLClassifier> bert_nl_classifier;
166 ASSIGN_OR_RETURN(
167 bert_nl_classifier,
168 core::TaskAPIFactory::CreateFromFileDescriptor<BertNLClassifier>(
169 fd, std::move(resolver)));
170 RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
171 return std::move(bert_nl_classifier);
172 }
173
InitializeFromMetadata()174 absl::Status BertNLClassifier::InitializeFromMetadata() {
175 // Set up mandatory tokenizer.
176 const ProcessUnit* tokenizer_process_unit =
177 GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
178 if (tokenizer_process_unit == nullptr) {
179 return CreateStatusWithPayload(
180 absl::StatusCode::kInvalidArgument,
181 "No input process unit found from metadata.",
182 TfLiteSupportStatus::kMetadataInvalidTokenizerError);
183 }
184 ASSIGN_OR_RETURN(tokenizer_,
185 CreateTokenizerFromProcessUnit(tokenizer_process_unit,
186 GetMetadataExtractor()));
187
188 // Set up optional label vector.
189 TrySetLabelFromMetadata(
190 GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex))
191 .IgnoreError();
192 return absl::OkStatus();
193 }
194
195 } // namespace nlclassifier
196 } // namespace text
197 } // namespace task
198 } // namespace tflite
199