• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/embedding-network.h"
18 
19 #include <math.h>
20 
21 #include "common/simple-adder.h"
22 #include "util/base/integral_types.h"
23 #include "util/base/logging.h"
24 
25 namespace libtextclassifier {
26 namespace nlp_core {
27 
28 namespace {
29 
30 // Returns true if and only if matrix does not use any quantization.
CheckNoQuantization(const EmbeddingNetworkParams::Matrix & matrix)31 bool CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
32   if (matrix.quant_type != QuantizationType::NONE) {
33     TC_LOG(ERROR) << "Unsupported quantization";
34     TC_DCHECK(false);  // Crash in debug mode.
35     return false;
36   }
37   return true;
38 }
39 
40 // Initializes a Matrix object with the parameters from the MatrixParams
41 // source_matrix.  source_matrix should not use quantization.
42 //
43 // Returns true on success, false on error.
InitNonQuantizedMatrix(const EmbeddingNetworkParams::Matrix & source_matrix,EmbeddingNetwork::Matrix * mat)44 bool InitNonQuantizedMatrix(const EmbeddingNetworkParams::Matrix &source_matrix,
45                             EmbeddingNetwork::Matrix *mat) {
46   mat->resize(source_matrix.rows);
47 
48   // Before we access the weights as floats, we need to check that they are
49   // really floats, i.e., no quantization is used.
50   if (!CheckNoQuantization(source_matrix)) return false;
51   const float *weights =
52       reinterpret_cast<const float *>(source_matrix.elements);
53   for (int r = 0; r < source_matrix.rows; ++r) {
54     (*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols);
55     weights += source_matrix.cols;
56   }
57   return true;
58 }
59 
60 // Initializes a VectorWrapper object with the parameters from the MatrixParams
61 // source_matrix.  source_matrix should have exactly one column and should not
62 // use quantization.
63 //
64 // Returns true on success, false on error.
InitNonQuantizedVector(const EmbeddingNetworkParams::Matrix & source_matrix,EmbeddingNetwork::VectorWrapper * vector)65 bool InitNonQuantizedVector(const EmbeddingNetworkParams::Matrix &source_matrix,
66                             EmbeddingNetwork::VectorWrapper *vector) {
67   if (source_matrix.cols != 1) {
68     TC_LOG(ERROR) << "wrong #cols " << source_matrix.cols;
69     return false;
70   }
71   if (!CheckNoQuantization(source_matrix)) {
72     TC_LOG(ERROR) << "unsupported quantization";
73     return false;
74   }
75   // Before we access the weights as floats, we need to check that they are
76   // really floats, i.e., no quantization is used.
77   if (!CheckNoQuantization(source_matrix)) return false;
78   const float *weights =
79       reinterpret_cast<const float *>(source_matrix.elements);
80   *vector = EmbeddingNetwork::VectorWrapper(weights, source_matrix.rows);
81   return true;
82 }
83 
84 // Computes y = weights * Relu(x) + b where Relu is optionally applied.
85 template <typename ScaleAdderClass>
SparseReluProductPlusBias(bool apply_relu,const EmbeddingNetwork::Matrix & weights,const EmbeddingNetwork::VectorWrapper & b,const VectorSpan<float> & x,EmbeddingNetwork::Vector * y)86 bool SparseReluProductPlusBias(bool apply_relu,
87                                const EmbeddingNetwork::Matrix &weights,
88                                const EmbeddingNetwork::VectorWrapper &b,
89                                const VectorSpan<float> &x,
90                                EmbeddingNetwork::Vector *y) {
91   // Check that dimensions match.
92   if ((x.size() != weights.size()) || weights.empty()) {
93     TC_LOG(ERROR) << x.size() << " != " << weights.size();
94     return false;
95   }
96   if (weights[0].size() != b.size()) {
97     TC_LOG(ERROR) << weights[0].size() << " != " << b.size();
98     return false;
99   }
100 
101   y->assign(b.data(), b.data() + b.size());
102   ScaleAdderClass adder(y->data(), y->size());
103 
104   const int x_size = x.size();
105   for (int i = 0; i < x_size; ++i) {
106     const float &scale = x[i];
107     if (apply_relu) {
108       if (scale > 0) {
109         adder.LazyScaleAdd(weights[i].data(), scale);
110       }
111     } else {
112       adder.LazyScaleAdd(weights[i].data(), scale);
113     }
114   }
115   return true;
116 }
117 }  // namespace
118 
ConcatEmbeddings(const std::vector<FeatureVector> & feature_vectors,Vector * concat) const119 bool EmbeddingNetwork::ConcatEmbeddings(
120     const std::vector<FeatureVector> &feature_vectors, Vector *concat) const {
121   concat->resize(concat_layer_size_);
122 
123   // Invariant 1: feature_vectors contains exactly one element for each
124   // embedding space.  That element is itself a FeatureVector, which may be
125   // empty, but it should be there.
126   if (feature_vectors.size() != embedding_matrices_.size()) {
127     TC_LOG(ERROR) << feature_vectors.size()
128                   << " != " << embedding_matrices_.size();
129     return false;
130   }
131 
132   // "es_index" stands for "embedding space index".
133   for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
134     // Access is safe by es_index loop bounds and Invariant 1.
135     EmbeddingMatrix *const embedding_matrix =
136         embedding_matrices_[es_index].get();
137     if (embedding_matrix == nullptr) {
138       // Should not happen, hence our terse log error message.
139       TC_LOG(ERROR) << es_index;
140       return false;
141     }
142 
143     // Access is safe due to es_index loop bounds.
144     const FeatureVector &feature_vector = feature_vectors[es_index];
145 
146     // Access is safe by es_index loop bounds, Invariant 1, and Invariant 2.
147     const int concat_offset = concat_offset_[es_index];
148 
149     if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset,
150                               concat->data(), concat->size())) {
151       TC_LOG(ERROR) << es_index;
152       return false;
153     }
154   }
155   return true;
156 }
157 
GetEmbedding(const FeatureVector & feature_vector,int es_index,float * embedding) const158 bool EmbeddingNetwork::GetEmbedding(const FeatureVector &feature_vector,
159                                     int es_index, float *embedding) const {
160   EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get();
161   if (embedding_matrix == nullptr) {
162     // Should not happen, hence our terse log error message.
163     TC_LOG(ERROR) << es_index;
164     return false;
165   }
166   return GetEmbeddingInternal(feature_vector, embedding_matrix, 0, embedding,
167                               embedding_matrices_[es_index]->dim());
168 }
169 
// Accumulates into concat[feature_offset .. feature_offset + dim) the
// (weighted) embedding of each feature from feature_vector, reading the raw
// weights (float or uint8-quantized) from embedding_matrix.  Returns false if
// any feature's write window would fall outside [0, concat_size).
bool EmbeddingNetwork::GetEmbeddingInternal(
    const FeatureVector &feature_vector,
    EmbeddingMatrix *const embedding_matrix, const int concat_offset,
    float *concat, int concat_size) const {
  const int embedding_dim = embedding_matrix->dim();
  const bool is_quantized =
      embedding_matrix->quant_type() != QuantizationType::NONE;
  const int num_features = feature_vector.size();
  for (int fi = 0; fi < num_features; ++fi) {
    // Both accesses below are safe due to loop bounds for fi.
    const FeatureType *feature_type = feature_vector.type(fi);
    const FeatureValue feature_value = feature_vector.value(fi);
    // Each feature type owns a contiguous band of embedding_dim floats in the
    // concat vector, starting at feature_type->base() embeddings past
    // concat_offset.
    const int feature_offset =
        concat_offset + feature_type->base() * embedding_dim;

    // Code below updates max(0, embedding_dim) elements from concat, starting
    // with index feature_offset.  Check below ensures these updates are safe.
    if ((feature_offset < 0) ||
        (feature_offset + embedding_dim > concat_size)) {
      TC_LOG(ERROR) << fi << ": " << feature_offset << " " << embedding_dim
                    << " " << concat_size;
      return false;
    }

    // Pointer to float / uint8 weights for relevant embedding.
    const void *embedding_data;

    // Multiplier for each embedding weight.
    float multiplier;

    if (feature_type->is_continuous()) {
      // Continuous features (encoded as FloatFeatureValue).  The feature
      // value packs both an embedding row id and a float weight; the weight
      // scales the embedding on top of any quantization multiplier.
      FloatFeatureValue float_feature_value(feature_value);
      const int id = float_feature_value.id;
      embedding_matrix->get_embedding(id, &embedding_data, &multiplier);
      multiplier *= float_feature_value.weight;
    } else {
      // Discrete features: every present feature has implicit value 1.0.
      // Hence, after we grab the multiplier below, we don't multiply it by
      // any weight.
      embedding_matrix->get_embedding(feature_value, &embedding_data,
                                      &multiplier);
    }

    // Weighted embeddings will be added starting from this address.
    float *concat_ptr = concat + feature_offset;

    if (is_quantized) {
      const uint8 *quant_weights =
          reinterpret_cast<const uint8 *>(embedding_data);
      for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
        // 128 is bias for UINT8 quantization, only one we currently support.
        *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
      }
    } else {
      // Non-quantized case: weights are plain floats.
      const float *weights = reinterpret_cast<const float *>(embedding_data);
      for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
        *concat_ptr += *weights * multiplier;
      }
    }
  }
  return true;
}
233 
ComputeLogits(const VectorSpan<float> & input,Vector * scores) const234 bool EmbeddingNetwork::ComputeLogits(const VectorSpan<float> &input,
235                                      Vector *scores) const {
236   return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
237 }
238 
ComputeLogits(const Vector & input,Vector * scores) const239 bool EmbeddingNetwork::ComputeLogits(const Vector &input,
240                                      Vector *scores) const {
241   return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
242 }
243 
ComputeLogitsInternal(const VectorSpan<float> & input,Vector * scores) const244 bool EmbeddingNetwork::ComputeLogitsInternal(const VectorSpan<float> &input,
245                                              Vector *scores) const {
246   return FinishComputeFinalScoresInternal<SimpleAdder>(input, scores);
247 }
248 
249 template <typename ScaleAdderClass>
FinishComputeFinalScoresInternal(const VectorSpan<float> & input,Vector * scores) const250 bool EmbeddingNetwork::FinishComputeFinalScoresInternal(
251     const VectorSpan<float> &input, Vector *scores) const {
252   // This vector serves as an alternating storage for activations of the
253   // different layers. We can't use just one vector here because all of the
254   // activations of  the previous layer are needed for computation of
255   // activations of the next one.
256   std::vector<Vector> h_storage(2);
257 
258   // Compute pre-logits activations.
259   VectorSpan<float> h_in(input);
260   Vector *h_out;
261   for (int i = 0; i < hidden_weights_.size(); ++i) {
262     const bool apply_relu = i > 0;
263     h_out = &(h_storage[i % 2]);
264     h_out->resize(hidden_bias_[i].size());
265     if (!SparseReluProductPlusBias<ScaleAdderClass>(
266             apply_relu, hidden_weights_[i], hidden_bias_[i], h_in, h_out)) {
267       return false;
268     }
269     h_in = VectorSpan<float>(*h_out);
270   }
271 
272   // Compute logit scores.
273   if (!SparseReluProductPlusBias<ScaleAdderClass>(
274           true, softmax_weights_, softmax_bias_, h_in, scores)) {
275     return false;
276   }
277 
278   return true;
279 }
280 
ComputeFinalScores(const std::vector<FeatureVector> & features,Vector * scores) const281 bool EmbeddingNetwork::ComputeFinalScores(
282     const std::vector<FeatureVector> &features, Vector *scores) const {
283   return ComputeFinalScores(features, {}, scores);
284 }
285 
ComputeFinalScores(const std::vector<FeatureVector> & features,const std::vector<float> extra_inputs,Vector * scores) const286 bool EmbeddingNetwork::ComputeFinalScores(
287     const std::vector<FeatureVector> &features,
288     const std::vector<float> extra_inputs, Vector *scores) const {
289   // If we haven't successfully initialized, return without doing anything.
290   if (!is_valid()) return false;
291 
292   Vector concat;
293   if (!ConcatEmbeddings(features, &concat)) return false;
294 
295   if (!extra_inputs.empty()) {
296     concat.reserve(concat.size() + extra_inputs.size());
297     for (int i = 0; i < extra_inputs.size(); i++) {
298       concat.push_back(extra_inputs[i]);
299     }
300   }
301 
302   scores->resize(softmax_bias_.size());
303   return ComputeLogits(concat, scores);
304 }
305 
// Builds the network from |model|: per-space embedding matrices (which must
// be UINT8-quantized and non-empty), the non-quantized hidden layers, and the
// non-quantized softmax layer.  On any error, logs and leaves valid_ false.
EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model) {
  // We'll set valid_ to true only if construction is successful.  If we detect
  // an error along the way, we log an informative message and return early, but
  // we do not crash.
  valid_ = false;

  // Fill embedding_matrices_, concat_offset_, and concat_layer_size_.
  const int num_embedding_spaces = model->GetNumEmbeddingSpaces();
  int offset_sum = 0;
  for (int i = 0; i < num_embedding_spaces; ++i) {
    // Record where this space's embeddings start in the concat layer.
    concat_offset_.push_back(offset_sum);
    const EmbeddingNetworkParams::Matrix matrix = model->GetEmbeddingMatrix(i);
    if (matrix.quant_type != QuantizationType::UINT8) {
      TC_LOG(ERROR) << "Unsupported quantization for embedding #" << i << ": "
                    << static_cast<int>(matrix.quant_type);
      return;
    }

    // There is no way to accommodate an empty embedding matrix.  E.g., there is
    // no way for get_embedding to return something that can be read safely.
    // Hence, we catch that error here and return early.
    if (matrix.rows == 0) {
      TC_LOG(ERROR) << "Empty embedding matrix #" << i;
      return;
    }
    embedding_matrices_.emplace_back(new EmbeddingMatrix(matrix));
    // Each space contributes dim * num_features floats to the concat layer.
    const int embedding_dim = embedding_matrices_.back()->dim();
    offset_sum += embedding_dim * model->GetNumFeaturesInEmbeddingSpace(i);
  }
  concat_layer_size_ = offset_sum;

  // Invariant 2 (trivial by the code above).
  TC_DCHECK_EQ(concat_offset_.size(), embedding_matrices_.size());

  const int num_hidden_layers = model->GetNumHiddenLayers();
  if (num_hidden_layers < 1) {
    TC_LOG(ERROR) << "Wrong number of hidden layers: " << num_hidden_layers;
    return;
  }
  hidden_weights_.resize(num_hidden_layers);
  hidden_bias_.resize(num_hidden_layers);

  // Hidden layers must be plain (non-quantized) floats; the Init* helpers
  // log and fail otherwise.
  for (int i = 0; i < num_hidden_layers; ++i) {
    const EmbeddingNetworkParams::Matrix matrix =
        model->GetHiddenLayerMatrix(i);
    const EmbeddingNetworkParams::Matrix bias = model->GetHiddenLayerBias(i);
    if (!InitNonQuantizedMatrix(matrix, &hidden_weights_[i]) ||
        !InitNonQuantizedVector(bias, &hidden_bias_[i])) {
      TC_LOG(ERROR) << "Bad hidden layer #" << i;
      return;
    }
  }

  if (!model->HasSoftmaxLayer()) {
    TC_LOG(ERROR) << "Missing softmax layer";
    return;
  }
  const EmbeddingNetworkParams::Matrix softmax = model->GetSoftmaxMatrix();
  const EmbeddingNetworkParams::Matrix softmax_bias = model->GetSoftmaxBias();
  if (!InitNonQuantizedMatrix(softmax, &softmax_weights_) ||
      !InitNonQuantizedVector(softmax_bias, &softmax_bias_)) {
    TC_LOG(ERROR) << "Bad softmax layer";
    return;
  }

  // Everything looks good.
  valid_ = true;
}
374 
EmbeddingSize(int es_index) const375 int EmbeddingNetwork::EmbeddingSize(int es_index) const {
376   return embedding_matrices_[es_index]->dim();
377 }
378 
379 }  // namespace nlp_core
380 }  // namespace libtextclassifier
381