1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_ 19 20 /** 21 * String projection op used in Self-Governing Neural Network (SGNN) 22 * and other ProjectionNet models for text prediction. 23 * The code is copied/adapted from 24 * learning/expander/pod/deep_pod/tflite_handlers/ 25 */ 26 27 #include <string> 28 #include <unordered_map> 29 #include <vector> 30 31 #include "flatbuffers/flexbuffers.h" 32 #include "tensorflow/lite/context.h" 33 34 namespace tflite { 35 namespace ops { 36 namespace custom { 37 namespace libtextclassifier3 { 38 namespace string_projection { 39 40 struct SkipGramParams { 41 // Num of tokens in ngram. 42 int ngram_size; 43 44 // Max num of tokens to skip in skip gram. 45 int max_skip_size; 46 47 // True when include all k-grams where k <= ngram_size. 48 bool include_all_ngrams; 49 50 // True when include preprocess. 51 bool preprocess; 52 53 // True when tokens are chars, false when tokens are whitespace separated. 54 bool char_level; 55 56 // True when punctuations are removed. 57 bool remove_punctuation; 58 59 // Max num of chars to process in input. 60 int max_input_chars; 61 }; 62 63 /** 64 * A framework for writing TFLite ops that convert strings to integers via LSH 65 * projections. Input is defined by the specific implementation. 66 * NOTE: Only supports dense projection. 67 * 68 * Attributes: 69 * num_hash: int[] 70 * number of hash functions 71 * num_bits: int[] 72 * number of bits in each hash function 73 * hash_function: float[num_hash * num_bits] 74 * hash_functions used to generate projections 75 * ngram_size: int[] 76 * maximum number of tokens in skipgrams 77 * max_skip_size: int[] 78 * maximum number of tokens to skip between tokens in skipgrams. 79 * include_all_ngrams: bool[] 80 * if false, only use skipgrams with ngram_size tokens 81 * preprocess: bool[] 82 * if true, normalize input strings (lower case, remove punctuation) 83 * hash_method: string[] 84 * hashing function to use 85 * char_level: bool[] 86 * if true, treat each character as a token 87 * binary_projection: bool[] 88 * if true, output features are 0 or 1 89 * remove_punctuation: bool[] 90 * if true, remove punctuation during normalization/preprocessing 91 * 92 * Output: 93 * tensor[0]: computed projections. float32[..., num_func * num_bits] 94 */ 95 96 class StringProjectionOpBase { 97 public: 98 explicit StringProjectionOpBase(const flexbuffers::Map& custom_options); 99 ~StringProjectionOpBase()100 virtual ~StringProjectionOpBase() {} 101 102 void GetFeatureWeights( 103 const std::unordered_map<std::string, int>& feature_counts, 104 std::vector<std::vector<int64_t>>* batch_ids, 105 std::vector<std::vector<float>>* batch_weights); 106 107 void DenseLshProjection(const int batch_size, 108 const std::vector<std::vector<int64_t>>& batch_ids, 109 const std::vector<std::vector<float>>& batch_weights, 110 TfLiteTensor* output); 111 num_hash()112 inline int num_hash() { return num_hash_; } num_bits()113 inline int num_bits() { return num_bits_; } 114 virtual TfLiteStatus InitializeInput(TfLiteContext* context, 115 TfLiteNode* node) = 0; 116 virtual std::unordered_map<std::string, int> ExtractSkipGrams(int i) = 0; 117 virtual void FinalizeInput() = 0; 118 119 // Returns the input shape. TfLiteIntArray is owned by the object. 120 virtual TfLiteIntArray* GetInputShape(TfLiteContext* context, 121 TfLiteNode* node) = 0; 122 123 protected: skip_gram_params()124 SkipGramParams& skip_gram_params() { return skip_gram_params_; } 125 126 private: 127 ::flexbuffers::TypedVector hash_function_; 128 int num_hash_; 129 int num_bits_; 130 bool binary_projection_; 131 std::string hash_method_; 132 float axb_scale_; 133 SkipGramParams skip_gram_params_; 134 135 // Compute sign bit of dot product of hash(seed, input) and weight. 136 float running_sign_bit(const std::vector<int64_t>& input, 137 const std::vector<float>& weight, float seed, 138 char* key); 139 }; 140 141 // Individual ops should define an Init() function that returns a 142 // BlacklistOpBase. 143 144 void Free(TfLiteContext* context, void* buffer); 145 146 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node); 147 148 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node); 149 150 } // namespace string_projection 151 } // namespace libtextclassifier3 152 } // namespace custom 153 } // namespace ops 154 } // namespace tflite 155 156 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_ 157