1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/common/embedding-network.h"
18
19 #include "lang_id/common/lite_base/integral-types.h"
20 #include "lang_id/common/lite_base/logging.h"
21
22 namespace libtextclassifier3 {
23 namespace mobile {
24 namespace {
25
CheckNoQuantization(const EmbeddingNetworkParams::Matrix & matrix)26 void CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
27 SAFTM_CHECK_EQ(static_cast<int>(QuantizationType::NONE),
28 static_cast<int>(matrix.quant_type))
29 << "Quantization not allowed here";
30 }
31
GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix & matrix)32 int GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix &matrix) {
33 int cols = matrix.cols;
34 QuantizationType quant_type = matrix.quant_type;
35 switch (quant_type) {
36 case QuantizationType::NONE:
37 return cols * sizeof(float);
38 case QuantizationType::UINT8:
39 return cols * sizeof(uint8);
40 case QuantizationType::UINT4:
41 SAFTM_DCHECK_EQ(cols % 2, 0) << "UINT4 with odd #cols = " << cols;
42 return cols / 2;
43 case QuantizationType::FLOAT16:
44 return cols * sizeof(float16);
45 default:
46 SAFTM_LOG(FATAL) << "Unknown quant type: "
47 << static_cast<int>(quant_type);
48 }
49 }
50
51 // Computes y = weights * Relu(x) + b where Relu is optionally applied.
52 //
53 // weights and b are the weight matrix, respectively the bias vector of a neural
54 // network layer.
55 //
56 // Note: in the research literature, usually Relu (the activation function) is
57 // the last part of a neural layer. From that perspective, this function
58 // computes the Relu part of the previous layer (if any) and next the first half
59 // (the computation of the state) for the current layer.
60 //
61 // Note: weights is expected to be the transposed version of the real weight
62 // matrix. Hence, instead of computing a linear combination of the columns of
63 // weights, we compute a linear combination of its rows; but we are mindful that
64 // these rows are the columns of the original matrix, hence the name
65 // weights_col_i in the code.
SparseReluProductPlusBias(bool apply_relu,const EmbeddingNetworkParams::Matrix & weights,const EmbeddingNetworkParams::Matrix & b,const std::vector<float> & x,std::vector<float> * y)66 void SparseReluProductPlusBias(bool apply_relu,
67 const EmbeddingNetworkParams::Matrix &weights,
68 const EmbeddingNetworkParams::Matrix &b,
69 const std::vector<float> &x,
70 std::vector<float> *y) {
71 // Initialize y to b. b is a column matrix (i.e., nb.cols == 1); we already
72 // CHECK-ed that the EmbeddingNetwork constructor.
73 const float *b_start = reinterpret_cast<const float *>(b.elements);
74 SAFTM_DCHECK_EQ(b.cols, 1);
75 y->assign(b_start, b_start + b.rows);
76
77 float *const y_data = y->data();
78 const int y_size = y->size();
79 SAFTM_CHECK_EQ(weights.cols, y_size);
80 const int x_size = x.size();
81 SAFTM_CHECK_EQ(weights.rows, x_size);
82
83 // NOTE: the code below reads x_size * y_size elements from weights; these
84 // reads are safe as long as weights.elements contains weights.rows *
85 // weights.cols elements (where the element size depends on the quantization
86 // type). That requirement is checked by the params provider, e.g., by
87 // EmbeddingNetworkParamsFromFlatbuffer.
88
89 // There is some code duplication between the two main cases of the switch
90 // below: the idea was to "lift" the switch outside the loops, to reduce the
91 // number of tests at runtime.
92 switch (weights.quant_type) {
93 case QuantizationType::NONE: {
94 // We compute a linear combination of the rows from |weights|, using
95 // elements of x (optionally, Relu(x)) as scaling factors (the i-th row
96 // gets multiplied by x[i] before being added with the other rows). Note:
97 // elements of |weights| are stored in row-major order: first the elements
98 // of row #0, next the elements of row #1, etc. In the comments below, we
99 // write "weights[i][j]" to refer to the j-th element from the i-th row of
100 // weights.
101 const float *weight_ptr =
102 reinterpret_cast<const float *>(weights.elements);
103 for (int i = 0; i < x_size; ++i) {
104 // Invariant 1: weight_ptr points to the beginning of the i-th row from
105 // weights (i.e., weights[i][0]).
106 const float scale = x[i];
107 if (!apply_relu || (scale > 0)) {
108 for (int j = 0; j < y_size; ++j, ++weight_ptr) {
109 // Invariant 2: weight_ptr points to weights[i][j].
110 y_data[j] += (*weight_ptr) * scale;
111 }
112 } else {
113 // We don't update y_data, but we still have to move weight_ptr to the
114 // next row (to satisfy Invariant 1). We do this by adding y_size ==
115 // weights.cols() (see earlier CHECK_EQ).
116 weight_ptr += y_size;
117 }
118 }
119 break;
120 }
121 case QuantizationType::FLOAT16: {
122 // See comments for the QuantizationType::NONE case: the code is almost
123 // identical, except for float16 (instead of float) and the Float16To32
124 // conversion. We could unify these two cases using a template, but since
125 // this is a critical loop, don't want to risk that e.g., inlining of the
126 // conversion function doesn't happen.
127 const float16 *weight_ptr =
128 reinterpret_cast<const float16 *>(weights.elements);
129 for (int i = 0; i < x_size; ++i) {
130 const float scale = x[i];
131 if (!apply_relu || (scale > 0)) {
132 for (int j = 0; j < y_size; ++j, ++weight_ptr) {
133 y_data[j] += Float16To32(*weight_ptr) * scale;
134 }
135 } else {
136 weight_ptr += y_size;
137 }
138 }
139 break;
140 }
141 default:
142 SAFTM_LOG(FATAL) << "Unsupported weights quantization type: "
143 << static_cast<int>(weights.quant_type);
144 }
145 }
146 } // namespace
147
ConcatEmbeddings(const std::vector<FeatureVector> & feature_vectors,std::vector<float> * concat) const148 void EmbeddingNetwork::ConcatEmbeddings(
149 const std::vector<FeatureVector> &feature_vectors,
150 std::vector<float> *concat) const {
151 concat->resize(concat_layer_size_);
152
153 // "es_index" stands for "embedding space index".
154 for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
155 const int concat_offset = concat_offset_[es_index];
156
157 const EmbeddingNetworkParams::Matrix &embedding_matrix =
158 embedding_matrices_[es_index];
159 const int embedding_dim = embedding_matrix.cols;
160 const int embedding_row_size_in_bytes =
161 embedding_row_size_in_bytes_[es_index];
162
163 const FeatureVector &feature_vector = feature_vectors[es_index];
164 const int num_features = feature_vector.size();
165 for (int fi = 0; fi < num_features; ++fi) {
166 const FeatureType *feature_type = feature_vector.type(fi);
167 int feature_offset = concat_offset + feature_type->base() * embedding_dim;
168 SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size());
169
170 // Weighted embeddings will be added starting from this address.
171 float *concat_ptr = concat->data() + feature_offset;
172
173 // Multiplier for each embedding weight. Includes feature weight (for
174 // continuous features) and quantization scale (for quantized embeddings).
175 float multiplier;
176 int feature_id;
177 const FeatureValue feature_value = feature_vector.value(fi);
178 if (feature_type->is_continuous()) {
179 // Continuous features (encoded as FloatFeatureValue).
180 FloatFeatureValue float_feature_value(feature_value);
181 feature_id = float_feature_value.id;
182 multiplier = float_feature_value.weight;
183 } else {
184 // Discrete features: every present feature has implicit value 1.0.
185 feature_id = feature_value;
186 multiplier = 1.0;
187 }
188
189 SAFTM_CHECK_GE(feature_id, 0);
190 SAFTM_CHECK_LT(feature_id, embedding_matrix.rows);
191
192 // Pointer to float / uint8 weights for relevant embedding.
193 const void *embedding_data =
194 (reinterpret_cast<const char *>(embedding_matrix.elements) +
195 feature_id * embedding_row_size_in_bytes);
196
197 switch (embedding_matrix.quant_type) {
198 case QuantizationType::NONE: {
199 const float *weights =
200 reinterpret_cast<const float *>(embedding_data);
201 for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
202 *concat_ptr += *weights * multiplier;
203 }
204 break;
205 }
206 case QuantizationType::UINT8: {
207 multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
208 const uint8 *quant_weights =
209 reinterpret_cast<const uint8 *>(embedding_data);
210 for (int i = 0; i < embedding_dim;
211 ++i, ++quant_weights, ++concat_ptr) {
212 // 128 is bias for UINT8 quantization.
213 *concat_ptr +=
214 (static_cast<int>(*quant_weights) - 128) * multiplier;
215 }
216 break;
217 }
218 case QuantizationType::UINT4: {
219 multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
220 const uint8 *quant_weights =
221 reinterpret_cast<const uint8 *>(embedding_data);
222 for (int i = 0; i < embedding_dim / 2; ++i, ++quant_weights) {
223 const uint8 qq = *quant_weights;
224 concat_ptr[0] +=
225 (static_cast<int>((qq & 0xF0) | 0x08) - 128) * multiplier;
226 concat_ptr[1] +=
227 (static_cast<int>(((qq & 0x0F) << 4) | 0x08) - 128) *
228 multiplier;
229 concat_ptr += 2;
230 }
231 break;
232 }
233 default:
234 // We already checked (in GetMatrixRowSizeInBytes) that each embedding
235 // matrix has a known quantization type. Hence, DLOG is enough here.
236 SAFTM_DLOG(ERROR) << "Unknown embeddings quantization type "
237 << static_cast<int>(embedding_matrix.quant_type);
238 break;
239 }
240 }
241 }
242 }
243
ComputeFinalScores(const std::vector<FeatureVector> & features,std::vector<float> * scores) const244 void EmbeddingNetwork::ComputeFinalScores(
245 const std::vector<FeatureVector> &features,
246 std::vector<float> *scores) const {
247 ComputeFinalScores(features, {}, scores);
248 }
249
ComputeFinalScores(const std::vector<FeatureVector> & features,const std::vector<float> & extra_inputs,std::vector<float> * scores) const250 void EmbeddingNetwork::ComputeFinalScores(
251 const std::vector<FeatureVector> &features,
252 const std::vector<float> &extra_inputs, std::vector<float> *scores) const {
253 // Construct the input layer for our feed-forward neural network (FFNN).
254 std::vector<float> input;
255 ConcatEmbeddings(features, &input);
256 if (!extra_inputs.empty()) {
257 input.reserve(input.size() + extra_inputs.size());
258 for (int i = 0; i < extra_inputs.size(); i++) {
259 input.push_back(extra_inputs[i]);
260 }
261 }
262
263 // Propagate input through all layers of our FFNN.
264
265 // Alternating storage for activations of the different layers. We can't use
266 // a single vector because all activations of the previous layer are required
267 // when computing the activations of the next one.
268 std::vector<float> storage[2];
269 const std::vector<float> *v_in = &input;
270 const int num_layers = layer_weights_.size();
271 for (int i = 0; i < num_layers; ++i) {
272 std::vector<float> *v_out = nullptr;
273 if (i == num_layers - 1) {
274 // Final layer: write results directly into |scores|.
275 v_out = scores;
276 } else {
277 // Hidden layer: write results into the alternating storage. The i % 2
278 // trick ensures the alternation.
279 v_out = &(storage[i % 2]);
280 }
281 const bool apply_relu = i > 0;
282 SparseReluProductPlusBias(
283 apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out);
284 v_in = v_out;
285 }
286 }
287
EmbeddingNetwork(const EmbeddingNetworkParams * model)288 EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
289 : model_(model) {
290 int offset_sum = 0;
291 for (int i = 0; i < model_->embedding_num_features_size(); ++i) {
292 concat_offset_.push_back(offset_sum);
293 EmbeddingNetworkParams::Matrix matrix = model_->GetEmbeddingMatrix(i);
294 offset_sum += matrix.cols * model_->embedding_num_features(i);
295
296 // NOTE: each Matrix is a small struct that doesn't own the actual matrix
297 // weights. Hence, the push_back below is fast.
298 embedding_matrices_.push_back(matrix);
299 embedding_row_size_in_bytes_.push_back(GetMatrixRowSizeInBytes(matrix));
300 }
301 concat_layer_size_ = offset_sum;
302
303 SAFTM_CHECK_EQ(model_->hidden_size(), model_->hidden_bias_size());
304 for (int i = 0; i < model_->hidden_size(); ++i) {
305 layer_weights_.push_back(model_->GetHiddenLayerMatrix(i));
306
307 EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
308 SAFTM_CHECK_EQ(1, bias.cols);
309 CheckNoQuantization(bias);
310 layer_bias_.push_back(bias);
311 }
312
313 SAFTM_CHECK(model_->HasSoftmax());
314 layer_weights_.push_back(model_->GetSoftmaxMatrix());
315
316 EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
317 SAFTM_CHECK_EQ(1, softmax_bias.cols);
318 CheckNoQuantization(softmax_bias);
319 layer_bias_.push_back(softmax_bias);
320 }
321
322 } // namespace mobile
}  // namespace libtextclassifier3
324