1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_ 18 #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_ 19 20 #include <memory> 21 #include <vector> 22 23 #include "common/embedding-network-params.h" 24 #include "common/feature-extractor.h" 25 #include "common/vector-span.h" 26 #include "util/base/integral_types.h" 27 #include "util/base/logging.h" 28 #include "util/base/macros.h" 29 30 namespace libtextclassifier { 31 namespace nlp_core { 32 33 // Classifier using a hand-coded feed-forward neural network. 34 // 35 // No gradient computation, just inference. 36 // 37 // Classification works as follows: 38 // 39 // Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax 40 // 41 // In words: given some discrete features, this class extracts the embeddings 42 // for these features, concatenates them, passes them through one or two hidden 43 // layers (each layer uses Relu) and next through a softmax layer that computes 44 // an unnormalized score for each possible class. Note: there is always a 45 // softmax layer. 46 class EmbeddingNetwork { 47 public: 48 // Class used to represent an embedding matrix. Each row is the embedding on 49 // a vocabulary element. Number of columns = number of embedding dimensions. 50 class EmbeddingMatrix { 51 public: EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix)52 explicit EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix) 53 : rows_(source_matrix.rows), 54 cols_(source_matrix.cols), 55 quant_type_(source_matrix.quant_type), 56 data_(source_matrix.elements), 57 row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)), 58 quant_scales_(source_matrix.quant_scales) {} 59 60 // Returns vocabulary size; one embedding for each vocabulary element. size()61 int size() const { return rows_; } 62 63 // Returns number of weights in embedding of each vocabulary element. dim()64 int dim() const { return cols_; } 65 66 // Returns quantization type for this embedding matrix. quant_type()67 QuantizationType quant_type() const { return quant_type_; } 68 69 // Gets embedding for k-th vocabulary element: on return, sets *data to 70 // point to the embedding weights and *scale to the quantization scale (1.0 71 // if no quantization). get_embedding(int k,const void ** data,float * scale)72 void get_embedding(int k, const void **data, float *scale) const { 73 if ((k < 0) || (k >= size())) { 74 TC_LOG(ERROR) << "Index outside [0, " << size() << "): " << k; 75 76 // In debug mode, crash. In prod, pretend that k is 0. 77 TC_DCHECK(false); 78 k = 0; 79 } 80 *data = reinterpret_cast<const char *>(data_) + k * row_size_in_bytes_; 81 if (quant_type_ == QuantizationType::NONE) { 82 *scale = 1.0; 83 } else { 84 *scale = Float16To32(quant_scales_[k]); 85 } 86 } 87 88 private: GetRowSizeInBytes(int cols,QuantizationType quant_type)89 static int GetRowSizeInBytes(int cols, QuantizationType quant_type) { 90 switch (quant_type) { 91 case QuantizationType::NONE: 92 return cols * sizeof(float); 93 case QuantizationType::UINT8: 94 return cols * sizeof(uint8); 95 default: 96 TC_LOG(ERROR) << "Unknown quant type: " 97 << static_cast<int>(quant_type); 98 return 0; 99 } 100 } 101 102 // Vocabulary size. 103 const int rows_; 104 105 // Number of elements in each embedding. 106 const int cols_; 107 108 const QuantizationType quant_type_; 109 110 // Pointer to the embedding weights, in row-major order. This is a pointer 111 // to an array of floats / uint8, depending on the quantization type. 112 // Not owned. 113 const void *const data_; 114 115 // Number of bytes for one row. Used to jump to next row in data_. 116 const int row_size_in_bytes_; 117 118 // Pointer to quantization scales. nullptr if no quantization. Otherwise, 119 // quant_scales_[i] is scale for embedding of i-th vocabulary element. 120 const float16 *const quant_scales_; 121 122 TC_DISALLOW_COPY_AND_ASSIGN(EmbeddingMatrix); 123 }; 124 125 // An immutable vector that doesn't own the memory that stores the underlying 126 // floats. Can be used e.g., as a wrapper around model weights stored in the 127 // static memory. 128 class VectorWrapper { 129 public: VectorWrapper()130 VectorWrapper() : VectorWrapper(nullptr, 0) {} 131 132 // Constructs a vector wrapper around the size consecutive floats that start 133 // at address data. Note: the underlying data should be alive for at least 134 // the lifetime of this VectorWrapper object. That's trivially true if data 135 // points to statically allocated data :) VectorWrapper(const float * data,int size)136 VectorWrapper(const float *data, int size) : data_(data), size_(size) {} 137 size()138 int size() const { return size_; } 139 data()140 const float *data() const { return data_; } 141 142 private: 143 const float *data_; // Not owned. 144 int size_; 145 146 // Doesn't own anything, so it can be copied and assigned at will :) 147 }; 148 149 typedef std::vector<VectorWrapper> Matrix; 150 typedef std::vector<float> Vector; 151 152 // Constructs an embedding network using the parameters from model. 153 // 154 // Note: model should stay alive for at least the lifetime of this 155 // EmbeddingNetwork object. 156 explicit EmbeddingNetwork(const EmbeddingNetworkParams *model); 157 ~EmbeddingNetwork()158 virtual ~EmbeddingNetwork() {} 159 160 // Returns true if this EmbeddingNetwork object has been correctly constructed 161 // and is ready to use. Idea: in case of errors, mark this EmbeddingNetwork 162 // object as invalid, but do not crash. is_valid()163 bool is_valid() const { return valid_; } 164 165 // Runs forward computation to fill scores with unnormalized output unit 166 // scores. This is useful for making predictions. 167 // 168 // Returns true on success, false on error (e.g., if !is_valid()). 169 bool ComputeFinalScores(const std::vector<FeatureVector> &features, 170 Vector *scores) const; 171 172 // Same as above, but allows specification of extra neural network inputs that 173 // will be appended to the embedding vector build from features. 174 bool ComputeFinalScores(const std::vector<FeatureVector> &features, 175 const std::vector<float> extra_inputs, 176 Vector *scores) const; 177 178 // Constructs the concatenated input embedding vector in place in output 179 // vector concat. Returns true on success, false on error. 180 bool ConcatEmbeddings(const std::vector<FeatureVector> &features, 181 Vector *concat) const; 182 183 // Sums embeddings for all features from |feature_vector| and adds result 184 // to values from the array pointed-to by |output|. Embeddings for continuous 185 // features are weighted by the feature weight. 186 // 187 // NOTE: output should point to an array of EmbeddingSize(es_index) floats. 188 bool GetEmbedding(const FeatureVector &feature_vector, int es_index, 189 float *embedding) const; 190 191 // Runs the feed-forward neural network for |input| and computes logits for 192 // softmax layer. 193 bool ComputeLogits(const Vector &input, Vector *scores) const; 194 195 // Same as above but uses a view of the feature vector. 196 bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const; 197 198 // Returns the size (the number of columns) of the embedding space es_index. 199 int EmbeddingSize(int es_index) const; 200 201 protected: 202 // Builds an embedding for given feature vector, and places it from 203 // concat_offset to the concat vector. 204 bool GetEmbeddingInternal(const FeatureVector &feature_vector, 205 EmbeddingMatrix *embedding_matrix, 206 int concat_offset, float *concat, 207 int embedding_size) const; 208 209 // Templated function that computes the logit scores given the concatenated 210 // input embeddings. 211 bool ComputeLogitsInternal(const VectorSpan<float> &concat, 212 Vector *scores) const; 213 214 // Computes the softmax scores (prior to normalization) from the concatenated 215 // representation. Returns true on success, false on error. 216 template <typename ScaleAdderClass> 217 bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat, 218 Vector *scores) const; 219 220 // Set to true on successful construction, false otherwise. 221 bool valid_ = false; 222 223 // Network parameters. 224 225 // One weight matrix for each embedding space. 226 std::vector<std::unique_ptr<EmbeddingMatrix>> embedding_matrices_; 227 228 // concat_offset_[i] is the input layer offset for i-th embedding space. 229 std::vector<int> concat_offset_; 230 231 // Size of the input ("concatenation") layer. 232 int concat_layer_size_; 233 234 // One weight matrix and one vector of bias weights for each hiden layer. 235 std::vector<Matrix> hidden_weights_; 236 std::vector<VectorWrapper> hidden_bias_; 237 238 // Weight matrix and bias vector for the softmax layer. 239 Matrix softmax_weights_; 240 VectorWrapper softmax_bias_; 241 }; 242 243 } // namespace nlp_core 244 } // namespace libtextclassifier 245 246 #endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_ 247