• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_
18 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_
19 
20 /**
21  * String projection op used in Self-Governing Neural Network (SGNN)
22  * and other ProjectionNet models for text prediction.
23  * The code is copied/adapted from
24  * learning/expander/pod/deep_pod/tflite_handlers/
25  */
26 
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/context.h"
33 
34 namespace tflite {
35 namespace ops {
36 namespace custom {
37 namespace libtextclassifier3 {
38 namespace string_projection {
39 
// Options controlling how skip-grams are extracted from an input string.
//
// Every field has a zero/false default so that a default-constructed instance
// never holds indeterminate values. NOTE(review): the real values are
// presumably filled in from the op's custom options by the code that owns
// this struct; the defaults are only a deterministic fallback.
struct SkipGramParams {
  // Number of tokens in an n-gram.
  int ngram_size = 0;

  // Maximum number of tokens that may be skipped inside a skip-gram.
  int max_skip_size = 0;

  // When true, all k-grams with k <= ngram_size are included; when false,
  // only skip-grams with exactly ngram_size tokens are used.
  bool include_all_ngrams = false;

  // When true, the input is preprocessed (normalized) before extraction.
  bool preprocess = false;

  // When true, each character is a token; when false, tokens are
  // whitespace-separated.
  bool char_level = false;

  // When true, punctuation is removed from the input.
  bool remove_punctuation = false;

  // Maximum number of input characters to process.
  int max_input_chars = 0;
};
62 
63 /**
64  * A framework for writing TFLite ops that convert strings to integers via LSH
65  * projections.  Input is defined by the specific implementation.
66  * NOTE: Only supports dense projection.
67  *
68  * Attributes:
69  *   num_hash:           int[]
70  *     number of hash functions
71  *   num_bits:           int[]
72  *     number of bits in each hash function
73  *   hash_function:      float[num_hash * num_bits]
74  *     hash_functions used to generate projections
75  *   ngram_size:         int[]
76  *     maximum number of tokens in skipgrams
77  *   max_skip_size:      int[]
78  *     maximum number of tokens to skip between tokens in skipgrams.
79  *   include_all_ngrams: bool[]
80  *     if false, only use skipgrams with ngram_size tokens
81  *   preprocess:         bool[]
82  *     if true, normalize input strings (lower case, remove punctuation)
83  *   hash_method:        string[]
84  *     hashing function to use
85  *   char_level:         bool[]
86  *     if true, treat each character as a token
87  *   binary_projection:  bool[]
88  *     if true, output features are 0 or 1
89  *   remove_punctuation: bool[]
90  *     if true, remove punctuation during normalization/preprocessing
91  *
 * Output:
 *   tensor[0]: computed projections. float32[..., num_hash * num_bits]
 */
95 
96 class StringProjectionOpBase {
97  public:
98   explicit StringProjectionOpBase(const flexbuffers::Map& custom_options);
99 
~StringProjectionOpBase()100   virtual ~StringProjectionOpBase() {}
101 
102   void GetFeatureWeights(
103       const std::unordered_map<std::string, int>& feature_counts,
104       std::vector<std::vector<int64_t>>* batch_ids,
105       std::vector<std::vector<float>>* batch_weights);
106 
107   void DenseLshProjection(const int batch_size,
108                           const std::vector<std::vector<int64_t>>& batch_ids,
109                           const std::vector<std::vector<float>>& batch_weights,
110                           TfLiteTensor* output);
111 
num_hash()112   inline int num_hash() { return num_hash_; }
num_bits()113   inline int num_bits() { return num_bits_; }
114   virtual TfLiteStatus InitializeInput(TfLiteContext* context,
115                                        TfLiteNode* node) = 0;
116   virtual std::unordered_map<std::string, int> ExtractSkipGrams(int i) = 0;
117   virtual void FinalizeInput() = 0;
118 
119   // Returns the input shape.  TfLiteIntArray is owned by the object.
120   virtual TfLiteIntArray* GetInputShape(TfLiteContext* context,
121                                         TfLiteNode* node) = 0;
122 
123  protected:
skip_gram_params()124   SkipGramParams& skip_gram_params() { return skip_gram_params_; }
125 
126  private:
127   ::flexbuffers::TypedVector hash_function_;
128   int num_hash_;
129   int num_bits_;
130   bool binary_projection_;
131   std::string hash_method_;
132   float axb_scale_;
133   SkipGramParams skip_gram_params_;
134 
135   // Compute sign bit of dot product of hash(seed, input) and weight.
136   float running_sign_bit(const std::vector<int64_t>& input,
137                          const std::vector<float>& weight, float seed,
138                          char* key);
139 };
140 
// Individual ops should define an Init() function that returns an instance of
// a StringProjectionOpBase subclass.
143 
144 void Free(TfLiteContext* context, void* buffer);
145 
146 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node);
147 
148 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node);
149 
150 }  // namespace string_projection
151 }  // namespace libtextclassifier3
152 }  // namespace custom
153 }  // namespace ops
154 }  // namespace tflite
155 
156 #endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_STRING_PROJECTION_BASE_H_
157