/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_

#include <memory>
#include <vector>

#include "common/embedding-network-params.h"
#include "common/feature-extractor.h"
#include "common/vector-span.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
#include "util/base/macros.h"

30 namespace libtextclassifier {
31 namespace nlp_core {
32 
33 // Classifier using a hand-coded feed-forward neural network.
34 //
35 // No gradient computation, just inference.
36 //
37 // Classification works as follows:
38 //
39 // Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
40 //
41 // In words: given some discrete features, this class extracts the embeddings
42 // for these features, concatenates them, passes them through one or two hidden
43 // layers (each layer uses Relu) and next through a softmax layer that computes
44 // an unnormalized score for each possible class.  Note: there is always a
45 // softmax layer.
46 class EmbeddingNetwork {
47  public:
48   // Class used to represent an embedding matrix.  Each row is the embedding on
49   // a vocabulary element.  Number of columns = number of embedding dimensions.
50   class EmbeddingMatrix {
51    public:
EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix)52     explicit EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix)
53         : rows_(source_matrix.rows),
54           cols_(source_matrix.cols),
55           quant_type_(source_matrix.quant_type),
56           data_(source_matrix.elements),
57           row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)),
58           quant_scales_(source_matrix.quant_scales) {}
59 
60     // Returns vocabulary size; one embedding for each vocabulary element.
size()61     int size() const { return rows_; }
62 
63     // Returns number of weights in embedding of each vocabulary element.
dim()64     int dim() const { return cols_; }
65 
66     // Returns quantization type for this embedding matrix.
quant_type()67     QuantizationType quant_type() const { return quant_type_; }
68 
69     // Gets embedding for k-th vocabulary element: on return, sets *data to
70     // point to the embedding weights and *scale to the quantization scale (1.0
71     // if no quantization).
get_embedding(int k,const void ** data,float * scale)72     void get_embedding(int k, const void **data, float *scale) const {
73       if ((k < 0) || (k >= size())) {
74         TC_LOG(ERROR) << "Index outside [0, " << size() << "): " << k;
75 
76         // In debug mode, crash.  In prod, pretend that k is 0.
77         TC_DCHECK(false);
78         k = 0;
79       }
80       *data = reinterpret_cast<const char *>(data_) + k * row_size_in_bytes_;
81       if (quant_type_ == QuantizationType::NONE) {
82         *scale = 1.0;
83       } else {
84         *scale = Float16To32(quant_scales_[k]);
85       }
86     }
87 
88    private:
GetRowSizeInBytes(int cols,QuantizationType quant_type)89     static int GetRowSizeInBytes(int cols, QuantizationType quant_type) {
90       switch (quant_type) {
91         case QuantizationType::NONE:
92           return cols * sizeof(float);
93         case QuantizationType::UINT8:
94           return cols * sizeof(uint8);
95         default:
96           TC_LOG(ERROR) << "Unknown quant type: "
97                         << static_cast<int>(quant_type);
98           return 0;
99       }
100     }
101 
102     // Vocabulary size.
103     const int rows_;
104 
105     // Number of elements in each embedding.
106     const int cols_;
107 
108     const QuantizationType quant_type_;
109 
110     // Pointer to the embedding weights, in row-major order.  This is a pointer
111     // to an array of floats / uint8, depending on the quantization type.
112     // Not owned.
113     const void *const data_;
114 
115     // Number of bytes for one row.  Used to jump to next row in data_.
116     const int row_size_in_bytes_;
117 
118     // Pointer to quantization scales.  nullptr if no quantization.  Otherwise,
119     // quant_scales_[i] is scale for embedding of i-th vocabulary element.
120     const float16 *const quant_scales_;
121 
122     TC_DISALLOW_COPY_AND_ASSIGN(EmbeddingMatrix);
123   };
124 
125   // An immutable vector that doesn't own the memory that stores the underlying
126   // floats.  Can be used e.g., as a wrapper around model weights stored in the
127   // static memory.
128   class VectorWrapper {
129    public:
VectorWrapper()130     VectorWrapper() : VectorWrapper(nullptr, 0) {}
131 
132     // Constructs a vector wrapper around the size consecutive floats that start
133     // at address data.  Note: the underlying data should be alive for at least
134     // the lifetime of this VectorWrapper object.  That's trivially true if data
135     // points to statically allocated data :)
VectorWrapper(const float * data,int size)136     VectorWrapper(const float *data, int size) : data_(data), size_(size) {}
137 
size()138     int size() const { return size_; }
139 
data()140     const float *data() const { return data_; }
141 
142    private:
143     const float *data_;  // Not owned.
144     int size_;
145 
146     // Doesn't own anything, so it can be copied and assigned at will :)
147   };
148 
149   typedef std::vector<VectorWrapper> Matrix;
150   typedef std::vector<float> Vector;
151 
152   // Constructs an embedding network using the parameters from model.
153   //
154   // Note: model should stay alive for at least the lifetime of this
155   // EmbeddingNetwork object.
156   explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
157 
~EmbeddingNetwork()158   virtual ~EmbeddingNetwork() {}
159 
160   // Returns true if this EmbeddingNetwork object has been correctly constructed
161   // and is ready to use.  Idea: in case of errors, mark this EmbeddingNetwork
162   // object as invalid, but do not crash.
is_valid()163   bool is_valid() const { return valid_; }
164 
165   // Runs forward computation to fill scores with unnormalized output unit
166   // scores. This is useful for making predictions.
167   //
168   // Returns true on success, false on error (e.g., if !is_valid()).
169   bool ComputeFinalScores(const std::vector<FeatureVector> &features,
170                           Vector *scores) const;
171 
172   // Same as above, but allows specification of extra neural network inputs that
173   // will be appended to the embedding vector build from features.
174   bool ComputeFinalScores(const std::vector<FeatureVector> &features,
175                           const std::vector<float> extra_inputs,
176                           Vector *scores) const;
177 
178   // Constructs the concatenated input embedding vector in place in output
179   // vector concat.  Returns true on success, false on error.
180   bool ConcatEmbeddings(const std::vector<FeatureVector> &features,
181                         Vector *concat) const;
182 
183   // Sums embeddings for all features from |feature_vector| and adds result
184   // to values from the array pointed-to by |output|.  Embeddings for continuous
185   // features are weighted by the feature weight.
186   //
187   // NOTE: output should point to an array of EmbeddingSize(es_index) floats.
188   bool GetEmbedding(const FeatureVector &feature_vector, int es_index,
189                     float *embedding) const;
190 
191   // Runs the feed-forward neural network for |input| and computes logits for
192   // softmax layer.
193   bool ComputeLogits(const Vector &input, Vector *scores) const;
194 
195   // Same as above but uses a view of the feature vector.
196   bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const;
197 
198   // Returns the size (the number of columns) of the embedding space es_index.
199   int EmbeddingSize(int es_index) const;
200 
201  protected:
202   // Builds an embedding for given feature vector, and places it from
203   // concat_offset to the concat vector.
204   bool GetEmbeddingInternal(const FeatureVector &feature_vector,
205                             EmbeddingMatrix *embedding_matrix,
206                             int concat_offset, float *concat,
207                             int embedding_size) const;
208 
209   // Templated function that computes the logit scores given the concatenated
210   // input embeddings.
211   bool ComputeLogitsInternal(const VectorSpan<float> &concat,
212                              Vector *scores) const;
213 
214   // Computes the softmax scores (prior to normalization) from the concatenated
215   // representation.  Returns true on success, false on error.
216   template <typename ScaleAdderClass>
217   bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat,
218                                         Vector *scores) const;
219 
220   // Set to true on successful construction, false otherwise.
221   bool valid_ = false;
222 
223   // Network parameters.
224 
225   // One weight matrix for each embedding space.
226   std::vector<std::unique_ptr<EmbeddingMatrix>> embedding_matrices_;
227 
228   // concat_offset_[i] is the input layer offset for i-th embedding space.
229   std::vector<int> concat_offset_;
230 
231   // Size of the input ("concatenation") layer.
232   int concat_layer_size_;
233 
234   // One weight matrix and one vector of bias weights for each hiden layer.
235   std::vector<Matrix> hidden_weights_;
236   std::vector<VectorWrapper> hidden_bias_;
237 
238   // Weight matrix and bias vector for the softmax layer.
239   Matrix softmax_weights_;
240   VectorWrapper softmax_bias_;
241 };
242 
243 }  // namespace nlp_core
244 }  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_