• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/pod_ner/pod-ner-impl.h"
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <ctime>
22 #include <iostream>
23 #include <memory>
24 #include <ostream>
25 #include <unordered_set>
26 #include <vector>
27 
28 #include "annotator/model_generated.h"
29 #include "annotator/pod_ner/utils.h"
30 #include "annotator/types.h"
31 #include "utils/base/logging.h"
32 #include "utils/bert_tokenizer.h"
33 #include "utils/tflite-model-executor.h"
34 #include "utils/tokenizer-utils.h"
35 #include "utils/utf8/unicodetext.h"
36 #include "absl/strings/ascii.h"
37 #include "tensorflow/lite/kernels/builtin_op_kernels.h"
38 #include "tensorflow/lite/mutable_op_resolver.h"
39 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
40 #include "tensorflow_models/seq_flow_lite/tflite_ops/layer_norm.h"
41 #include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
42 
43 namespace libtextclassifier3 {
44 
45 using PodNerModel_::CollectionT;
46 using PodNerModel_::LabelT;
47 using ::tflite::support::text::tokenizer::TokenizerResult;
48 
49 namespace {
50 
51 using PodNerModel_::Label_::BoiseType;
52 using PodNerModel_::Label_::BoiseType_BEGIN;
53 using PodNerModel_::Label_::BoiseType_END;
54 using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
55 using PodNerModel_::Label_::BoiseType_O;
56 using PodNerModel_::Label_::BoiseType_SINGLE;
57 using PodNerModel_::Label_::MentionType;
58 using PodNerModel_::Label_::MentionType_NAM;
59 using PodNerModel_::Label_::MentionType_NOM;
60 using PodNerModel_::Label_::MentionType_UNDEFINED;
61 
EmplaceToLabelVector(BoiseType boise_type,MentionType mention_type,int collection_id,std::vector<LabelT> * labels)62 void EmplaceToLabelVector(BoiseType boise_type, MentionType mention_type,
63                           int collection_id, std::vector<LabelT> *labels) {
64   labels->emplace_back();
65   labels->back().boise_type = boise_type;
66   labels->back().mention_type = mention_type;
67   labels->back().collection_id = collection_id;
68 }
69 
FillDefaultLabelsAndCollections(float default_priority,std::vector<LabelT> * labels,std::vector<CollectionT> * collections)70 void FillDefaultLabelsAndCollections(float default_priority,
71                                      std::vector<LabelT> *labels,
72                                      std::vector<CollectionT> *collections) {
73   std::vector<std::string> collection_names = {
74       "art",          "consumer_good", "event",  "location",
75       "organization", "ner_entity",    "person", "undefined"};
76   collections->clear();
77   for (const std::string &collection_name : collection_names) {
78     collections->emplace_back();
79     collections->back().name = collection_name;
80     collections->back().single_token_priority_score = default_priority;
81     collections->back().multi_token_priority_score = default_priority;
82   }
83 
84   labels->clear();
85   for (auto boise_type :
86        {BoiseType_BEGIN, BoiseType_END, BoiseType_INTERMEDIATE}) {
87     for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
88       for (int i = 0; i < collections->size() - 1; ++i) {  // skip undefined
89         EmplaceToLabelVector(boise_type, mention_type, i, labels);
90       }
91     }
92   }
93   EmplaceToLabelVector(BoiseType_O, MentionType_UNDEFINED, 7, labels);
94   for (auto mention_type : {MentionType_NAM, MentionType_NOM}) {
95     for (int i = 0; i < collections->size() - 1; ++i) {  // skip undefined
96       EmplaceToLabelVector(BoiseType_SINGLE, mention_type, i, labels);
97     }
98   }
99 }
100 
CreateInterpreter(const PodNerModel * model)101 std::unique_ptr<tflite::Interpreter> CreateInterpreter(
102     const PodNerModel *model) {
103   TC3_CHECK(model != nullptr);
104   if (model->tflite_model() == nullptr) {
105     TC3_LOG(ERROR) << "Unable to create tf.lite interpreter, model is null.";
106     return nullptr;
107   }
108 
109   const tflite::Model *tflite_model =
110       tflite::GetModel(model->tflite_model()->Data());
111   if (tflite_model == nullptr) {
112     TC3_LOG(ERROR) << "Unable to create tf.lite interpreter, model is null.";
113     return nullptr;
114   }
115 
116   std::unique_ptr<tflite::OpResolver> resolver =
117       BuildOpResolver([](tflite::MutableOpResolver *mutable_resolver) {
118         mutable_resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
119                                      ::tflite::ops::builtin::Register_SHAPE());
120         mutable_resolver->AddBuiltin(::tflite::BuiltinOperator_RANGE,
121                                      ::tflite::ops::builtin::Register_RANGE());
122         mutable_resolver->AddBuiltin(
123             ::tflite::BuiltinOperator_ARG_MAX,
124             ::tflite::ops::builtin::Register_ARG_MAX());
125         mutable_resolver->AddBuiltin(
126             ::tflite::BuiltinOperator_EXPAND_DIMS,
127             ::tflite::ops::builtin::Register_EXPAND_DIMS());
128         mutable_resolver->AddCustom(
129             "LayerNorm", ::seq_flow_lite::ops::custom::Register_LAYER_NORM());
130       });
131 
132   std::unique_ptr<tflite::Interpreter> tflite_interpreter;
133   tflite::InterpreterBuilder(tflite_model, *resolver,
134                              nullptr)(&tflite_interpreter);
135   if (tflite_interpreter == nullptr) {
136     TC3_LOG(ERROR) << "Unable to create tf.lite interpreter.";
137     return nullptr;
138   }
139   return tflite_interpreter;
140 }
141 
FindSpecialWordpieceIds(const std::unique_ptr<BertTokenizer> & tokenizer,int * cls_id,int * sep_id,int * period_id,int * unknown_id)142 bool FindSpecialWordpieceIds(const std::unique_ptr<BertTokenizer> &tokenizer,
143                              int *cls_id, int *sep_id, int *period_id,
144                              int *unknown_id) {
145   if (!tokenizer->LookupId("[CLS]", cls_id)) {
146     TC3_LOG(ERROR) << "Couldn't find [CLS] wordpiece.";
147     return false;
148   }
149   if (!tokenizer->LookupId("[SEP]", sep_id)) {
150     TC3_LOG(ERROR) << "Couldn't find [SEP] wordpiece.";
151     return false;
152   }
153   if (!tokenizer->LookupId(".", period_id)) {
154     TC3_LOG(ERROR) << "Couldn't find [.] wordpiece.";
155     return false;
156   }
157   if (!tokenizer->LookupId("[UNK]", unknown_id)) {
158     TC3_LOG(ERROR) << "Couldn't find [UNK] wordpiece.";
159     return false;
160   }
161   return true;
162 }
163 // WARNING: This tokenizer is not exactly the one the model was trained with
164 // so there might be nuances.
CreateTokenizer(const PodNerModel * model)165 std::unique_ptr<BertTokenizer> CreateTokenizer(const PodNerModel *model) {
166   TC3_CHECK(model != nullptr);
167   if (model->word_piece_vocab() == nullptr) {
168     TC3_LOG(ERROR)
169         << "Unable to create tokenizer, model or word_pieces is null.";
170     return nullptr;
171   }
172 
173   return std::unique_ptr<BertTokenizer>(new BertTokenizer(
174       reinterpret_cast<const char *>(model->word_piece_vocab()->Data()),
175       model->word_piece_vocab()->size()));
176 }
177 
178 }  // namespace
179 
Create(const PodNerModel * model,const UniLib & unilib)180 std::unique_ptr<PodNerAnnotator> PodNerAnnotator::Create(
181     const PodNerModel *model, const UniLib &unilib) {
182   if (model == nullptr) {
183     TC3_LOG(ERROR) << "Create received null model.";
184     return nullptr;
185   }
186 
187   std::unique_ptr<BertTokenizer> tokenizer = CreateTokenizer(model);
188   if (tokenizer == nullptr) {
189     return nullptr;
190   }
191 
192   int cls_id, sep_id, period_id, unknown_wordpiece_id;
193   if (!FindSpecialWordpieceIds(tokenizer, &cls_id, &sep_id, &period_id,
194                                &unknown_wordpiece_id)) {
195     return nullptr;
196   }
197 
198   std::unique_ptr<PodNerAnnotator> annotator(new PodNerAnnotator(unilib));
199   annotator->tokenizer_ = std::move(tokenizer);
200   annotator->lowercase_input_ = model->lowercase_input();
201   annotator->logits_index_in_output_tensor_ =
202       model->logits_index_in_output_tensor();
203   annotator->append_final_period_ = model->append_final_period();
204   if (model->labels() && model->labels()->size() > 0 && model->collections() &&
205       model->collections()->size() > 0) {
206     annotator->labels_.clear();
207     for (const PodNerModel_::Label *label : *model->labels()) {
208       annotator->labels_.emplace_back();
209       annotator->labels_.back().boise_type = label->boise_type();
210       annotator->labels_.back().mention_type = label->mention_type();
211       annotator->labels_.back().collection_id = label->collection_id();
212     }
213     for (const PodNerModel_::Collection *collection : *model->collections()) {
214       annotator->collections_.emplace_back();
215       annotator->collections_.back().name = collection->name()->str();
216       annotator->collections_.back().single_token_priority_score =
217           collection->single_token_priority_score();
218       annotator->collections_.back().multi_token_priority_score =
219           collection->multi_token_priority_score();
220     }
221   } else {
222     FillDefaultLabelsAndCollections(
223         model->priority_score(), &annotator->labels_, &annotator->collections_);
224   }
225   int max_num_surrounding_wordpieces = model->append_final_period() ? 3 : 2;
226   annotator->max_num_effective_wordpieces_ =
227       model->max_num_wordpieces() - max_num_surrounding_wordpieces;
228   annotator->sliding_window_num_wordpieces_overlap_ =
229       model->sliding_window_num_wordpieces_overlap();
230   annotator->max_ratio_unknown_wordpieces_ =
231       model->max_ratio_unknown_wordpieces();
232   annotator->min_number_of_tokens_ = model->min_number_of_tokens();
233   annotator->min_number_of_wordpieces_ = model->min_number_of_wordpieces();
234   annotator->cls_wordpiece_id_ = cls_id;
235   annotator->sep_wordpiece_id_ = sep_id;
236   annotator->period_wordpiece_id_ = period_id;
237   annotator->unknown_wordpiece_id_ = unknown_wordpiece_id;
238   annotator->model_ = model;
239 
240   return annotator;
241 }
242 
ReadResultsFromInterpreter(tflite::Interpreter & interpreter) const243 std::vector<LabelT> PodNerAnnotator::ReadResultsFromInterpreter(
244     tflite::Interpreter &interpreter) const {
245   TfLiteTensor *output =
246       interpreter.tensor(interpreter.outputs()[logits_index_in_output_tensor_]);
247   TC3_CHECK_EQ(output->dims->size, 3);
248   TC3_CHECK_EQ(output->dims->data[0], 1);
249   TC3_CHECK_EQ(output->dims->data[2], labels_.size());
250   std::vector<LabelT> return_value(output->dims->data[1]);
251   std::vector<float> probs(output->dims->data[1]);
252   for (int step = 0, index = 0; step < output->dims->data[1]; ++step) {
253     float max_prob = 0.0f;
254     int max_index = 0;
255     for (int cindex = 0; cindex < output->dims->data[2]; ++cindex) {
256       const float probability =
257           ::seq_flow_lite::PodDequantize(*output, index++);
258       if (probability > max_prob) {
259         max_prob = probability;
260         max_index = cindex;
261       }
262     }
263     return_value[step] = labels_[max_index];
264     probs[step] = max_prob;
265   }
266   return return_value;
267 }
268 
ExecuteModel(const VectorSpan<int> & wordpiece_indices,const VectorSpan<int32_t> & token_starts,const VectorSpan<Token> & tokens) const269 std::vector<LabelT> PodNerAnnotator::ExecuteModel(
270     const VectorSpan<int> &wordpiece_indices,
271     const VectorSpan<int32_t> &token_starts,
272     const VectorSpan<Token> &tokens) const {
273   // Check that there are not more input indices than supported.
274   if (wordpiece_indices.size() > max_num_effective_wordpieces_) {
275     TC3_LOG(ERROR) << "More than " << max_num_effective_wordpieces_
276                    << " indices passed to POD NER model.";
277     return {};
278   }
279   if (wordpiece_indices.size() <= 0 || token_starts.size() <= 0 ||
280       tokens.size() <= 0) {
281     TC3_LOG(ERROR) << "ExecuteModel received illegal input, #wordpiece_indices="
282                    << wordpiece_indices.size()
283                    << " #token_starts=" << token_starts.size()
284                    << " #tokens=" << tokens.size();
285     return {};
286   }
287 
288   // For the CLS (at the beginning) and SEP (at the end) wordpieces.
289   int num_additional_wordpieces = 2;
290   bool should_append_final_period = false;
291   // Optionally add a final period wordpiece if the final token is not
292   // already punctuation. This can improve performance for models trained on
293   // data mostly ending in sentence-final punctuation.
294   const std::string &last_token = (tokens.end() - 1)->value;
295   if (append_final_period_ &&
296       (last_token.size() != 1 || !unilib_.IsPunctuation(last_token.at(0)))) {
297     should_append_final_period = true;
298     num_additional_wordpieces++;
299   }
300 
301   // Interpreter needs to be created for each inference call separately,
302   // otherwise the class is not thread-safe.
303   std::unique_ptr<tflite::Interpreter> interpreter = CreateInterpreter(model_);
304   if (interpreter == nullptr) {
305     TC3_LOG(ERROR) << "Couldn't create Interpreter.";
306     return {};
307   }
308 
309   TfLiteStatus status;
310   status = interpreter->ResizeInputTensor(
311       interpreter->inputs()[0],
312       {1, wordpiece_indices.size() + num_additional_wordpieces});
313   TC3_CHECK_EQ(status, kTfLiteOk);
314   status = interpreter->ResizeInputTensor(interpreter->inputs()[1],
315                                           {1, token_starts.size()});
316   TC3_CHECK_EQ(status, kTfLiteOk);
317 
318   status = interpreter->AllocateTensors();
319   TC3_CHECK_EQ(status, kTfLiteOk);
320 
321   TfLiteTensor *tensor = interpreter->tensor(interpreter->inputs()[0]);
322   int wordpiece_tensor_index = 0;
323   tensor->data.i32[wordpiece_tensor_index++] = cls_wordpiece_id_;
324   for (int wordpiece_index : wordpiece_indices) {
325     tensor->data.i32[wordpiece_tensor_index++] = wordpiece_index;
326   }
327 
328   if (should_append_final_period) {
329     tensor->data.i32[wordpiece_tensor_index++] = period_wordpiece_id_;
330   }
331   tensor->data.i32[wordpiece_tensor_index++] = sep_wordpiece_id_;
332 
333   tensor = interpreter->tensor(interpreter->inputs()[1]);
334   for (int i = 0; i < token_starts.size(); ++i) {
335     // Need to add one because of the starting CLS wordpiece and reduce the
336     // offset from the first wordpiece.
337     tensor->data.i32[i] = token_starts[i] + 1 - token_starts[0];
338   }
339 
340   status = interpreter->Invoke();
341   TC3_CHECK_EQ(status, kTfLiteOk);
342 
343   return ReadResultsFromInterpreter(*interpreter);
344 }
345 
PrepareText(const UnicodeText & text_unicode,std::vector<int32_t> * wordpiece_indices,std::vector<int32_t> * token_starts,std::vector<Token> * tokens) const346 bool PodNerAnnotator::PrepareText(const UnicodeText &text_unicode,
347                                   std::vector<int32_t> *wordpiece_indices,
348                                   std::vector<int32_t> *token_starts,
349                                   std::vector<Token> *tokens) const {
350   *tokens = TokenizeOnWhiteSpacePunctuationAndChineseLetter(
351       text_unicode.ToUTF8String());
352   tokens->erase(std::remove_if(tokens->begin(), tokens->end(),
353                                [](const Token &token) {
354                                  return token.start == token.end;
355                                }),
356                 tokens->end());
357 
358   for (const Token &token : *tokens) {
359     const std::string token_text =
360         lowercase_input_ ? unilib_
361                                .ToLowerText(UTF8ToUnicodeText(
362                                    token.value, /*do_copy=*/false))
363                                .ToUTF8String()
364                          : token.value;
365 
366     const TokenizerResult wordpiece_tokenization =
367         tokenizer_->TokenizeSingleToken(token_text);
368 
369     std::vector<int> wordpiece_ids;
370     for (const std::string &wordpiece : wordpiece_tokenization.subwords) {
371       if (!tokenizer_->LookupId(wordpiece, &(wordpiece_ids.emplace_back()))) {
372         TC3_LOG(ERROR) << "Couldn't find wordpiece " << wordpiece;
373         return false;
374       }
375     }
376 
377     if (wordpiece_ids.empty()) {
378       TC3_LOG(ERROR) << "wordpiece_ids.empty()";
379       return false;
380     }
381     token_starts->push_back(wordpiece_indices->size());
382     for (const int64 wordpiece_id : wordpiece_ids) {
383       wordpiece_indices->push_back(wordpiece_id);
384     }
385   }
386 
387   return true;
388 }
389 
Annotate(const UnicodeText & context,std::vector<AnnotatedSpan> * results) const390 bool PodNerAnnotator::Annotate(const UnicodeText &context,
391                                std::vector<AnnotatedSpan> *results) const {
392   return AnnotateAroundSpanOfInterest(context, {0, context.size_codepoints()},
393                                       results);
394 }
395 
AnnotateAroundSpanOfInterest(const UnicodeText & context,const CodepointSpan & span_of_interest,std::vector<AnnotatedSpan> * results) const396 bool PodNerAnnotator::AnnotateAroundSpanOfInterest(
397     const UnicodeText &context, const CodepointSpan &span_of_interest,
398     std::vector<AnnotatedSpan> *results) const {
399   TC3_CHECK(results != nullptr);
400 
401   std::vector<int32_t> wordpiece_indices;
402   std::vector<int32_t> token_starts;
403   std::vector<Token> tokens;
404   if (!PrepareText(context, &wordpiece_indices, &token_starts, &tokens)) {
405     TC3_LOG(ERROR) << "PodNerAnnotator PrepareText(...) failed.";
406     return false;
407   }
408   const int unknown_wordpieces_count =
409       std::count(wordpiece_indices.begin(), wordpiece_indices.end(),
410                  unknown_wordpiece_id_);
411   if (tokens.empty() || tokens.size() < min_number_of_tokens_ ||
412       wordpiece_indices.size() < min_number_of_wordpieces_ ||
413       (static_cast<float>(unknown_wordpieces_count) /
414        wordpiece_indices.size()) > max_ratio_unknown_wordpieces_) {
415     return true;
416   }
417 
418   std::vector<LabelT> labels;
419   int first_token_index_entire_window = 0;
420 
421   WindowGenerator window_generator(
422       wordpiece_indices, token_starts, tokens, max_num_effective_wordpieces_,
423       sliding_window_num_wordpieces_overlap_, span_of_interest);
424   while (!window_generator.Done()) {
425     VectorSpan<int32_t> cur_wordpiece_indices;
426     VectorSpan<int32_t> cur_token_starts;
427     VectorSpan<Token> cur_tokens;
428     if (!window_generator.Next(&cur_wordpiece_indices, &cur_token_starts,
429                                &cur_tokens) ||
430         cur_tokens.size() <= 0 || cur_token_starts.size() <= 0 ||
431         cur_wordpiece_indices.size() <= 0) {
432       return false;
433     }
434     std::vector<LabelT> new_labels =
435         ExecuteModel(cur_wordpiece_indices, cur_token_starts, cur_tokens);
436     if (labels.empty()) {  // First loop.
437       first_token_index_entire_window = cur_tokens.begin() - tokens.begin();
438     }
439     if (!MergeLabelsIntoLeftSequence(
440             /*labels_right=*/new_labels,
441             /*index_first_right_tag_in_left=*/cur_tokens.begin() -
442                 tokens.begin() - first_token_index_entire_window,
443             /*labels_left=*/&labels)) {
444       return false;
445     }
446   }
447 
448   if (labels.empty()) {
449     return false;
450   }
451   ConvertTagsToAnnotatedSpans(
452       VectorSpan<Token>(tokens.begin() + first_token_index_entire_window,
453                         tokens.end()),
454       labels, collections_, {PodNerModel_::Label_::MentionType_NAM},
455       /*relaxed_inside_label_matching=*/false,
456       /*relaxed_mention_type_matching=*/false, results);
457 
458   return true;
459 }
460 
SuggestSelection(const UnicodeText & context,CodepointSpan click,AnnotatedSpan * result) const461 bool PodNerAnnotator::SuggestSelection(const UnicodeText &context,
462                                        CodepointSpan click,
463                                        AnnotatedSpan *result) const {
464   TC3_VLOG(INFO) << "POD NER SuggestSelection " << click;
465   std::vector<AnnotatedSpan> annotations;
466   if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
467     TC3_VLOG(INFO) << "POD NER SuggestSelection: Annotate error. Returning: "
468                    << click;
469     *result = {};
470     return false;
471   }
472 
473   for (const AnnotatedSpan &annotation : annotations) {
474     TC3_VLOG(INFO) << "POD NER SuggestSelection: " << annotation;
475     if (annotation.span.first <= click.first &&
476         annotation.span.second >= click.second) {
477       TC3_VLOG(INFO) << "POD NER SuggestSelection: Accepted.";
478       *result = annotation;
479       return true;
480     }
481   }
482 
483   TC3_VLOG(INFO)
484       << "POD NER SuggestSelection: No annotation matched click. Returning: "
485       << click;
486   *result = {};
487   return false;
488 }
489 
ClassifyText(const UnicodeText & context,CodepointSpan click,ClassificationResult * result) const490 bool PodNerAnnotator::ClassifyText(const UnicodeText &context,
491                                    CodepointSpan click,
492                                    ClassificationResult *result) const {
493   TC3_VLOG(INFO) << "POD NER ClassifyText " << click;
494   std::vector<AnnotatedSpan> annotations;
495   if (!AnnotateAroundSpanOfInterest(context, click, &annotations)) {
496     return false;
497   }
498 
499   for (const AnnotatedSpan &annotation : annotations) {
500     if (annotation.span.first <= click.first &&
501         annotation.span.second >= click.second) {
502       if (annotation.classification.empty()) {
503         return false;
504       }
505       *result = annotation.classification[0];
506       return true;
507     }
508   }
509   return false;
510 }
511 
GetSupportedCollections() const512 std::vector<std::string> PodNerAnnotator::GetSupportedCollections() const {
513   std::vector<std::string> result;
514   for (const PodNerModel_::CollectionT &collection : collections_) {
515     result.push_back(collection.name);
516   }
517   return result;
518 }
519 
520 }  // namespace libtextclassifier3
521