1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/pod_ner/pod-ner-impl.h"
18
19 #include <algorithm>
20 #include <cstdint>
21 #include <ctime>
22 #include <iostream>
23 #include <memory>
24 #include <ostream>
25 #include <unordered_set>
26 #include <vector>
27
28 #include "annotator/model_generated.h"
29 #include "annotator/pod_ner/utils.h"
30 #include "annotator/types.h"
31 #include "utils/base/logging.h"
32 #include "utils/bert_tokenizer.h"
33 #include "utils/tflite-model-executor.h"
34 #include "utils/tokenizer-utils.h"
35 #include "utils/utf8/unicodetext.h"
36 #include "absl/strings/ascii.h"
37 #include "tensorflow/lite/kernels/builtin_op_kernels.h"
38 #include "tensorflow/lite/mutable_op_resolver.h"
39 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
40 #include "tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h"
41 #include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
42
43 namespace libtextclassifier3 {
44
45 using PodNerModel_::CollectionT;
46 using PodNerModel_::LabelT;
47 using ::tflite::support::text::tokenizer::TokenizerResult;
48
49 namespace {
50
51 using PodNerModel_::Label_::BoiseType;
52 using PodNerModel_::Label_::BoiseType_BEGIN;
53 using PodNerModel_::Label_::BoiseType_END;
54 using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
55 using PodNerModel_::Label_::BoiseType_O;
56 using PodNerModel_::Label_::BoiseType_SINGLE;
57 using PodNerModel_::Label_::MentionType;
58 using PodNerModel_::Label_::MentionType_NAM;
59 using PodNerModel_::Label_::MentionType_NOM;
60 using PodNerModel_::Label_::MentionType_UNDEFINED;
61
EmplaceToLabelVector(BoiseType boise_type,MentionType mention_type,int collection_id,std::vector<LabelT> * labels)62 void EmplaceToLabelVector(BoiseType boise_type, MentionType mention_type,
63 int collection_id, std::vector<LabelT> *labels) {
64 labels->emplace_back();
65 labels->back().boise_type = boise_type;
66 labels->back().mention_type = mention_type;
67 labels->back().collection_id = collection_id;
68 }
69
FillDefaultLabelsAndCollections(float default_priority,std::vector<LabelT> * labels,std::vector<CollectionT> * collections)70 void FillDefaultLabelsAndCollections(float default_priority,
71 std::vector<LabelT> *labels,
72 std::vector<CollectionT> *collections) {
73 std::vector<std::string> collection_names = {
74 "art", "consumer_good", "event", "location",
75 "organization", "ner_entity", "person", "undefined"};
76 collections->clear();
77 for (const std::string &collection_name : collection_names) {
78 collections->emplace_back();
79 collections->back().name = collection_name;
80 collections->back().single_token_priority_score = default_priority;
81 collections->back().multi_token_priority_score = default_priority;
82 }
83
84 labels->clear();
85 for (auto boise_type :
86 {BoiseType_BEGIN, BoiseType_END, BoiseType_INTERMEDIATE}) {
87 for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
88 for (int i = 0; i < collections->size() - 1; ++i) { // skip undefined
89 EmplaceToLabelVector(boise_type, mention_type, i, labels);
90 }
91 }
92 }
93 EmplaceToLabelVector(BoiseType_O, MentionType_UNDEFINED, 7, labels);
94 for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
95 for (int i = 0; i < collections->size() - 1; ++i) { // skip undefined
96 EmplaceToLabelVector(BoiseType_SINGLE, mention_type, i, labels);
97 }
98 }
99 }
100
CreateInterpreter(const PodNerModel * model)101 std::unique_ptr<tflite::Interpreter> CreateInterpreter(
102 const PodNerModel *model) {
103 TC3_CHECK(model != nullptr);
104 if (model->tflite_model() == nullptr) {
105 TC3_LOG(ERROR) << "Unable to create tf.lite interpreter, model is null.";
106 return nullptr;
107 }
108
109 const tflite::Model *tflite_model =
110 tflite::GetModel(model->tflite_model()->Data());
111 if (tflite_model == nullptr) {
112 TC3_LOG(ERROR) << "Unable to create tf.lite interpreter, model is null.";
113 return nullptr;
114 }
115
116 std::unique_ptr<tflite::OpResolver> resolver =
117 BuildOpResolver([](tflite::MutableOpResolver *mutable_resolver) {
118 mutable_resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
119 ::tflite::ops::builtin::Register_SHAPE());
120 mutable_resolver->AddBuiltin(::tflite::BuiltinOperator_RANGE,
121 ::tflite::ops::builtin::Register_RANGE());
122 mutable_resolver->AddBuiltin(
123 ::tflite::BuiltinOperator_ARG_MAX,
124 ::tflite::ops::builtin::Register_ARG_MAX());
125 mutable_resolver->AddBuiltin(
126 ::tflite::BuiltinOperator_EXPAND_DIMS,
127 ::tflite::ops::builtin::Register_EXPAND_DIMS());
128 mutable_resolver->AddCustom(
129 "LayerNorm", ::seq_flow_lite::ops::custom::Register_LAYER_NORM());
130 });
131
132 std::unique_ptr<tflite::Interpreter> tflite_interpreter;
133 tflite::InterpreterBuilder(tflite_model, *resolver,
134 nullptr)(&tflite_interpreter);
135 if (tflite_interpreter == nullptr) {
136 TC3_LOG(ERROR) << "Unable to create tf.lite interpreter.";
137 return nullptr;
138 }
139 return tflite_interpreter;
140 }
141
FindSpecialWordpieceIds(const std::unique_ptr<BertTokenizer> & tokenizer,int * cls_id,int * sep_id,int * period_id,int * unknown_id)142 bool FindSpecialWordpieceIds(const std::unique_ptr<BertTokenizer> &tokenizer,
143 int *cls_id, int *sep_id, int *period_id,
144 int *unknown_id) {
145 if (!tokenizer->LookupId("[CLS]", cls_id)) {
146 TC3_LOG(ERROR) << "Couldn't find [CLS] wordpiece.";
147 return false;
148 }
149 if (!tokenizer->LookupId("[SEP]", sep_id)) {
150 TC3_LOG(ERROR) << "Couldn't find [SEP] wordpiece.";
151 return false;
152 }
153 if (!tokenizer->LookupId(".", period_id)) {
154 TC3_LOG(ERROR) << "Couldn't find [.] wordpiece.";
155 return false;
156 }
157 if (!tokenizer->LookupId("[UNK]", unknown_id)) {
158 TC3_LOG(ERROR) << "Couldn't find [UNK] wordpiece.";
159 return false;
160 }
161 return true;
162 }
163 // WARNING: This tokenizer is not exactly the one the model was trained with
164 // so there might be nuances.
CreateTokenizer(const PodNerModel * model)165 std::unique_ptr<BertTokenizer> CreateTokenizer(const PodNerModel *model) {
166 TC3_CHECK(model != nullptr);
167 if (model->word_piece_vocab() == nullptr) {
168 TC3_LOG(ERROR)
169 << "Unable to create tokenizer, model or word_pieces is null.";
170 return nullptr;
171 }
172
173 return std::unique_ptr<BertTokenizer>(new BertTokenizer(
174 reinterpret_cast<const char *>(model->word_piece_vocab()->Data()),
175 model->word_piece_vocab()->size()));
176 }
177
178 } // namespace
179
// Factory: builds a fully-configured PodNerAnnotator from the flatbuffer
// |model|. Returns nullptr if |model| is null, the tokenizer cannot be
// created, or any special wordpiece ([CLS], [SEP], ., [UNK]) is missing from
// the vocabulary.
std::unique_ptr<PodNerAnnotator> PodNerAnnotator::Create(
    const PodNerModel *model, const UniLib &unilib) {
  if (model == nullptr) {
    TC3_LOG(ERROR) << "Create received null model.";
    return nullptr;
  }

  std::unique_ptr<BertTokenizer> tokenizer = CreateTokenizer(model);
  if (tokenizer == nullptr) {
    return nullptr;
  }

  int cls_id, sep_id, period_id, unknown_wordpiece_id;
  if (!FindSpecialWordpieceIds(tokenizer, &cls_id, &sep_id, &period_id,
                               &unknown_wordpiece_id)) {
    return nullptr;
  }

  std::unique_ptr<PodNerAnnotator> annotator(new PodNerAnnotator(unilib));
  annotator->tokenizer_ = std::move(tokenizer);
  annotator->lowercase_input_ = model->lowercase_input();
  annotator->logits_index_in_output_tensor_ =
      model->logits_index_in_output_tensor();
  annotator->append_final_period_ = model->append_final_period();
  // Prefer the label/collection inventory carried by the model; fall back to
  // the hard-coded defaults only when either list is absent or empty.
  if (model->labels() && model->labels()->size() > 0 && model->collections() &&
      model->collections()->size() > 0) {
    annotator->labels_.clear();
    // Copy flatbuffer Label/Collection entries into the owned *T structs.
    for (const PodNerModel_::Label *label : *model->labels()) {
      annotator->labels_.emplace_back();
      annotator->labels_.back().boise_type = label->boise_type();
      annotator->labels_.back().mention_type = label->mention_type();
      annotator->labels_.back().collection_id = label->collection_id();
    }
    for (const PodNerModel_::Collection *collection : *model->collections()) {
      annotator->collections_.emplace_back();
      annotator->collections_.back().name = collection->name()->str();
      annotator->collections_.back().single_token_priority_score =
          collection->single_token_priority_score();
      annotator->collections_.back().multi_token_priority_score =
          collection->multi_token_priority_score();
    }
  } else {
    FillDefaultLabelsAndCollections(
        model->priority_score(), &annotator->labels_, &annotator->collections_);
  }
  // Budget for the wordpieces added around every window: [CLS] and [SEP],
  // plus the optional final period when append_final_period is set.
  int max_num_surrounding_wordpieces = model->append_final_period() ? 3 : 2;
  annotator->max_num_effective_wordpieces_ =
      model->max_num_wordpieces() - max_num_surrounding_wordpieces;
  annotator->sliding_window_num_wordpieces_overlap_ =
      model->sliding_window_num_wordpieces_overlap();
  annotator->max_ratio_unknown_wordpieces_ =
      model->max_ratio_unknown_wordpieces();
  annotator->min_number_of_tokens_ = model->min_number_of_tokens();
  annotator->min_number_of_wordpieces_ = model->min_number_of_wordpieces();
  annotator->cls_wordpiece_id_ = cls_id;
  annotator->sep_wordpiece_id_ = sep_id;
  annotator->period_wordpiece_id_ = period_id;
  annotator->unknown_wordpiece_id_ = unknown_wordpiece_id;
  // NOTE(review): the raw |model| pointer is retained; presumably the caller
  // guarantees it outlives the annotator — confirm against the owner.
  annotator->model_ = model;

  return annotator;
}
242
ReadResultsFromInterpreter(tflite::Interpreter & interpreter) const243 std::vector<LabelT> PodNerAnnotator::ReadResultsFromInterpreter(
244 tflite::Interpreter &interpreter) const {
245 TfLiteTensor *output =
246 interpreter.tensor(interpreter.outputs()[logits_index_in_output_tensor_]);
247 TC3_CHECK_EQ(output->dims->size, 3);
248 TC3_CHECK_EQ(output->dims->data[0], 1);
249 TC3_CHECK_EQ(output->dims->data[2], labels_.size());
250 std::vector<LabelT> return_value(output->dims->data[1]);
251 std::vector<float> probs(output->dims->data[1]);
252 for (int step = 0, index = 0; step < output->dims->data[1]; ++step) {
253 float max_prob = 0.0f;
254 int max_index = 0;
255 for (int cindex = 0; cindex < output->dims->data[2]; ++cindex) {
256 const float probability =
257 ::seq_flow_lite::PodDequantize(*output, index++);
258 if (probability > max_prob) {
259 max_prob = probability;
260 max_index = cindex;
261 }
262 }
263 return_value[step] = labels_[max_index];
264 probs[step] = max_prob;
265 }
266 return return_value;
267 }
268
// Runs one inference of the POD NER TFLite model over a single window.
// |wordpiece_indices| holds the window's wordpiece vocabulary ids,
// |token_starts| the index of each token's first wordpiece, and |tokens| the
// corresponding tokens. Returns one label per token position, or an empty
// vector on any error.
std::vector<LabelT> PodNerAnnotator::ExecuteModel(
    const VectorSpan<int> &wordpiece_indices,
    const VectorSpan<int32_t> &token_starts,
    const VectorSpan<Token> &tokens) const {
  // Check that there are not more input indices than supported.
  if (wordpiece_indices.size() > max_num_effective_wordpieces_) {
    TC3_LOG(ERROR) << "More than " << max_num_effective_wordpieces_
                   << " indices passed to POD NER model.";
    return {};
  }
  if (wordpiece_indices.size() <= 0 || token_starts.size() <= 0 ||
      tokens.size() <= 0) {
    TC3_LOG(ERROR) << "ExecuteModel received illegal input, #wordpiece_indices="
                   << wordpiece_indices.size()
                   << " #token_starts=" << token_starts.size()
                   << " #tokens=" << tokens.size();
    return {};
  }

  // For the CLS (at the beginning) and SEP (at the end) wordpieces.
  int num_additional_wordpieces = 2;
  bool should_append_final_period = false;
  // Optionally add a final period wordpiece if the final token is not
  // already punctuation. This can improve performance for models trained on
  // data mostly ending in sentence-final punctuation.
  const std::string &last_token = (tokens.end() - 1)->value;
  if (append_final_period_ &&
      (last_token.size() != 1 || !unilib_.IsPunctuation(last_token.at(0)))) {
    should_append_final_period = true;
    num_additional_wordpieces++;
  }

  // Interpreter needs to be created for each inference call separately,
  // otherwise the class is not thread-safe.
  std::unique_ptr<tflite::Interpreter> interpreter = CreateInterpreter(model_);
  if (interpreter == nullptr) {
    TC3_LOG(ERROR) << "Couldn't create Interpreter.";
    return {};
  }

  // Resize input 0 (wordpiece ids, incl. the surrounding special pieces) and
  // input 1 (token start offsets) to this window's size, then allocate.
  TfLiteStatus status;
  status = interpreter->ResizeInputTensor(
      interpreter->inputs()[0],
      {1, wordpiece_indices.size() + num_additional_wordpieces});
  TC3_CHECK_EQ(status, kTfLiteOk);
  status = interpreter->ResizeInputTensor(interpreter->inputs()[1],
                                          {1, token_starts.size()});
  TC3_CHECK_EQ(status, kTfLiteOk);

  status = interpreter->AllocateTensors();
  TC3_CHECK_EQ(status, kTfLiteOk);

  // Fill input 0 in order: [CLS], window wordpieces, optional '.', [SEP].
  TfLiteTensor *tensor = interpreter->tensor(interpreter->inputs()[0]);
  int wordpiece_tensor_index = 0;
  tensor->data.i32[wordpiece_tensor_index++] = cls_wordpiece_id_;
  for (int wordpiece_index : wordpiece_indices) {
    tensor->data.i32[wordpiece_tensor_index++] = wordpiece_index;
  }

  if (should_append_final_period) {
    tensor->data.i32[wordpiece_tensor_index++] = period_wordpiece_id_;
  }
  tensor->data.i32[wordpiece_tensor_index++] = sep_wordpiece_id_;

  // Fill input 1: token start offsets rebased to this window's first
  // wordpiece.
  tensor = interpreter->tensor(interpreter->inputs()[1]);
  for (int i = 0; i < token_starts.size(); ++i) {
    // Need to add one because of the starting CLS wordpiece and reduce the
    // offset from the first wordpiece.
    tensor->data.i32[i] = token_starts[i] + 1 - token_starts[0];
  }

  status = interpreter->Invoke();
  TC3_CHECK_EQ(status, kTfLiteOk);

  return ReadResultsFromInterpreter(*interpreter);
}
345
PrepareText(const UnicodeText & text_unicode,std::vector<int32_t> * wordpiece_indices,std::vector<int32_t> * token_starts,std::vector<Token> * tokens) const346 bool PodNerAnnotator::PrepareText(const UnicodeText &text_unicode,
347 std::vector<int32_t> *wordpiece_indices,
348 std::vector<int32_t> *token_starts,
349 std::vector<Token> *tokens) const {
350 *tokens = TokenizeOnWhiteSpacePunctuationAndChineseLetter(
351 text_unicode.ToUTF8String());
352 tokens->erase(std::remove_if(tokens->begin(), tokens->end(),
353 [](const Token &token) {
354 return token.start == token.end;
355 }),
356 tokens->end());
357
358 for (const Token &token : *tokens) {
359 const std::string token_text =
360 lowercase_input_ ? unilib_
361 .ToLowerText(UTF8ToUnicodeText(
362 token.value, /*do_copy=*/false))
363 .ToUTF8String()
364 : token.value;
365
366 const TokenizerResult wordpiece_tokenization =
367 tokenizer_->TokenizeSingleToken(token_text);
368
369 std::vector<int> wordpiece_ids;
370 for (const std::string &wordpiece : wordpiece_tokenization.subwords) {
371 if (!tokenizer_->LookupId(wordpiece, &(wordpiece_ids.emplace_back()))) {
372 TC3_LOG(ERROR) << "Couldn't find wordpiece " << wordpiece;
373 return false;
374 }
375 }
376
377 if (wordpiece_ids.empty()) {
378 TC3_LOG(ERROR) << "wordpiece_ids.empty()";
379 return false;
380 }
381 token_starts->push_back(wordpiece_indices->size());
382 for (const int64 wordpiece_id : wordpiece_ids) {
383 wordpiece_indices->push_back(wordpiece_id);
384 }
385 }
386
387 return true;
388 }
389
Annotate(const UnicodeText & context,std::vector<AnnotatedSpan> * results) const390 bool PodNerAnnotator::Annotate(const UnicodeText &context,
391 std::vector<AnnotatedSpan> *results) const {
392 return AnnotateAroundSpanOfInterest(context, {0, context.size_codepoints()},
393 results);
394 }
395
// Annotates the part of |context| relevant to |span_of_interest| by sliding
// (possibly overlapping) wordpiece windows over the text, running the model
// on each window and merging the per-window label sequences. Returns false on
// model/windowing failure; returns true with no results when the input is too
// small or contains too many unknown wordpieces to annotate reliably.
bool PodNerAnnotator::AnnotateAroundSpanOfInterest(
    const UnicodeText &context, const CodepointSpan &span_of_interest,
    std::vector<AnnotatedSpan> *results) const {
  TC3_CHECK(results != nullptr);

  std::vector<int32_t> wordpiece_indices;
  std::vector<int32_t> token_starts;
  std::vector<Token> tokens;
  if (!PrepareText(context, &wordpiece_indices, &token_starts, &tokens)) {
    TC3_LOG(ERROR) << "PodNerAnnotator PrepareText(...) failed.";
    return false;
  }
  // Early out (success, no annotations) for inputs below the size thresholds
  // or above the unknown-wordpiece ratio threshold.
  const int unknown_wordpieces_count =
      std::count(wordpiece_indices.begin(), wordpiece_indices.end(),
                 unknown_wordpiece_id_);
  if (tokens.empty() || tokens.size() < min_number_of_tokens_ ||
      wordpiece_indices.size() < min_number_of_wordpieces_ ||
      (static_cast<float>(unknown_wordpieces_count) /
       wordpiece_indices.size()) > max_ratio_unknown_wordpieces_) {
    return true;
  }

  std::vector<LabelT> labels;
  // Token index (into |tokens|) where the first window started; used to
  // rebase window offsets and to slice |tokens| for span conversion below.
  int first_token_index_entire_window = 0;

  WindowGenerator window_generator(
      wordpiece_indices, token_starts, tokens, max_num_effective_wordpieces_,
      sliding_window_num_wordpieces_overlap_, span_of_interest);
  while (!window_generator.Done()) {
    VectorSpan<int32_t> cur_wordpiece_indices;
    VectorSpan<int32_t> cur_token_starts;
    VectorSpan<Token> cur_tokens;
    if (!window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
                               &cur_tokens) ||
        cur_tokens.size() <= 0 || cur_token_starts.size() <= 0 ||
        cur_wordpiece_indices.size() <= 0) {
      return false;
    }
    std::vector<LabelT> new_labels =
        ExecuteModel(cur_wordpiece_indices, cur_token_starts, cur_tokens);
    if (labels.empty()) {  // First loop.
      first_token_index_entire_window = cur_tokens.begin() - tokens.begin();
    }
    // Stitch this window's labels onto the accumulated sequence; overlap
    // resolution happens inside MergeLabelsIntoLeftSequence.
    if (!MergeLabelsIntoLeftSequence(
            /*labels_right=*/new_labels,
            /*index_first_right_tag_in_left=*/cur_tokens.begin() -
                tokens.begin() - first_token_index_entire_window,
            /*labels_left=*/&labels)) {
      return false;
    }
  }

  if (labels.empty()) {
    return false;
  }
  // Convert the merged BOISE tag sequence into annotated spans; only NAM
  // (name) mentions are emitted, with strict label/mention-type matching.
  ConvertTagsToAnnotatedSpans(
      VectorSpan<Token>(tokens.begin() + first_token_index_entire_window,
                        tokens.end()),
      labels, collections_, {PodNerModel_::Label_::MentionType_NAM},
      /*relaxed_inside_label_matching=*/false,
      /*relaxed_mention_type_matching=*/false, results);

  return true;
}
460
SuggestSelection(const UnicodeText & context,CodepointSpan click,AnnotatedSpan * result) const461 bool PodNerAnnotator::SuggestSelection(const UnicodeText &context,
462 CodepointSpan click,
463 AnnotatedSpan *result) const {
464 TC3_VLOG(INFO) << "POD NER SuggestSelection " << click;
465 std::vector<AnnotatedSpan> annotations;
466 if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
467 TC3_VLOG(INFO) << "POD NER SuggestSelection: Annotate error. Returning: "
468 << click;
469 *result = {};
470 return false;
471 }
472
473 for (const AnnotatedSpan &annotation : annotations) {
474 TC3_VLOG(INFO) << "POD NER SuggestSelection: " << annotation;
475 if (annotation.span.first <= click.first &&
476 annotation.span.second >= click.second) {
477 TC3_VLOG(INFO) << "POD NER SuggestSelection: Accepted.";
478 *result = annotation;
479 return true;
480 }
481 }
482
483 TC3_VLOG(INFO)
484 << "POD NER SuggestSelection: No annotation matched click. Returning: "
485 << click;
486 *result = {};
487 return false;
488 }
489
ClassifyText(const UnicodeText & context,CodepointSpan click,ClassificationResult * result) const490 bool PodNerAnnotator::ClassifyText(const UnicodeText &context,
491 CodepointSpan click,
492 ClassificationResult *result) const {
493 TC3_VLOG(INFO) << "POD NER ClassifyText " << click;
494 std::vector<AnnotatedSpan> annotations;
495 if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
496 return false;
497 }
498
499 for (const AnnotatedSpan &annotation : annotations) {
500 if (annotation.span.first <= click.first &&
501 annotation.span.second >= click.second) {
502 if (annotation.classification.empty()) {
503 return false;
504 }
505 *result = annotation.classification[0];
506 return true;
507 }
508 }
509 return false;
510 }
511
GetSupportedCollections() const512 std::vector<std::string> PodNerAnnotator::GetSupportedCollections() const {
513 std::vector<std::string> result;
514 for (const PodNerModel_::CollectionT &collection : collections_) {
515 result.push_back(collection.name);
516 }
517 return result;
518 }
519
520 } // namespace libtextclassifier3
521