1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite/blacklist.h"
18
19 #include "utils/tflite/blacklist_base.h"
20 #include "utils/tflite/skipgram_finder.h"
21 #include "flatbuffers/flexbuffers.h"
22
23 namespace tflite {
24 namespace ops {
25 namespace custom {
26
27 namespace libtextclassifier3 {
28 namespace blacklist {
29
30 // Generates prediction vectors for input strings using a skipgram blacklist.
31 // This uses the framework in `blacklist_base.h`, with the implementation detail
32 // that the input is a string tensor of messages and the terms are skipgrams.
33 class BlacklistOp : public BlacklistOpBase {
34 public:
BlacklistOp(const flexbuffers::Map & custom_options)35 explicit BlacklistOp(const flexbuffers::Map& custom_options)
36 : BlacklistOpBase(custom_options),
37 skipgram_finder_(custom_options["max_skip_size"].AsInt32()),
38 input_(nullptr) {
39 auto blacklist = custom_options["blacklist"].AsTypedVector();
40 auto blacklist_category =
41 custom_options["blacklist_category"].AsTypedVector();
42 for (int i = 0; i < blacklist.size(); i++) {
43 int category = blacklist_category[i].AsInt32();
44 flexbuffers::String s = blacklist[i].AsString();
45 skipgram_finder_.AddSkipgram(std::string(s.c_str(), s.length()),
46 category);
47 }
48 }
49
InitializeInput(TfLiteContext * context,TfLiteNode * node)50 TfLiteStatus InitializeInput(TfLiteContext* context,
51 TfLiteNode* node) override {
52 input_ = &context->tensors[node->inputs->data[kInputMessage]];
53 return kTfLiteOk;
54 }
55
GetCategories(int i) const56 absl::flat_hash_set<int> GetCategories(int i) const override {
57 StringRef input = GetString(input_, i);
58 return skipgram_finder_.FindSkipgrams(std::string(input.str, input.len));
59 }
60
FinalizeInput()61 void FinalizeInput() override { input_ = nullptr; }
62
GetInputShape(TfLiteContext * context,TfLiteNode * node)63 TfLiteIntArray* GetInputShape(TfLiteContext* context,
64 TfLiteNode* node) override {
65 return context->tensors[node->inputs->data[kInputMessage]].dims;
66 }
67
68 private:
69 ::libtextclassifier3::SkipgramFinder skipgram_finder_;
70 TfLiteTensor* input_;
71
72 static constexpr int kInputMessage = 0;
73 };
74
BlacklistOpInit(TfLiteContext * context,const char * buffer,size_t length)75 void* BlacklistOpInit(TfLiteContext* context, const char* buffer,
76 size_t length) {
77 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
78 return new BlacklistOp(flexbuffers::GetRoot(buffer_t, length).AsMap());
79 }
80
81 } // namespace blacklist
82
Register_BLACKLIST()83 TfLiteRegistration* Register_BLACKLIST() {
84 static TfLiteRegistration r = {libtextclassifier3::blacklist::BlacklistOpInit,
85 libtextclassifier3::blacklist::Free,
86 libtextclassifier3::blacklist::Resize,
87 libtextclassifier3::blacklist::Eval};
88 return &r;
89 }
90
91 } // namespace libtextclassifier3
92 } // namespace custom
93 } // namespace ops
94 } // namespace tflite
95