1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_ 19 20 #include "absl/container/flat_hash_set.h" 21 #include "flatbuffers/flexbuffers.h" 22 #include "tensorflow/lite/context.h" 23 24 namespace tflite { 25 namespace ops { 26 namespace custom { 27 namespace libtextclassifier3 { 28 namespace blacklist { 29 30 /* 31 * A framework for writing ops that generates prediction vectors using a 32 * blacklist. 33 * 34 * Input is defined by the specific implementation. 35 * 36 * Attributes: 37 * blacklist: string[n] 38 * Terms in the blacklist. 39 * blacklist_category: int[n] 40 * Category for each term in the blacklist. Each category must be in 41 * [0, categories). 42 * categories: int[] 43 * Total number of categories. 44 * negative_categories: int[] 45 * Total number of negative categories. 46 * 47 * Output: 48 * tensor[0]: Category indicators for each message, float[..., categories] 49 * 50 */ 51 52 class BlacklistOpBase { 53 public: BlacklistOpBase(const flexbuffers::Map & custom_options)54 explicit BlacklistOpBase(const flexbuffers::Map& custom_options) 55 : categories_(custom_options["categories"].AsInt32()), 56 negative_categories_(custom_options["negative_categories"].AsInt32()) {} 57 ~BlacklistOpBase()58 virtual ~BlacklistOpBase() {} 59 categories()60 int categories() const { return categories_; } negative_categories()61 int negative_categories() const { return negative_categories_; } 62 63 virtual TfLiteStatus InitializeInput(TfLiteContext* context, 64 TfLiteNode* node) = 0; 65 virtual absl::flat_hash_set<int> GetCategories(int i) const = 0; 66 virtual void FinalizeInput() = 0; 67 68 // Returns the input shape. TfLiteIntArray is owned by the object. 69 virtual TfLiteIntArray* GetInputShape(TfLiteContext* context, 70 TfLiteNode* node) = 0; 71 72 private: 73 int categories_; 74 int negative_categories_; 75 }; 76 77 // Individual ops should define an Init() function that returns a 78 // BlacklistOpBase. 79 80 void Free(TfLiteContext* context, void* buffer); 81 82 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node); 83 84 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node); 85 } // namespace blacklist 86 } // namespace libtextclassifier3 87 } // namespace custom 88 } // namespace ops 89 } // namespace tflite 90 91 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_BLACKLIST_BASE_H_ 92