/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"

#include "lang_id/common/lite_base/endian.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"

namespace libtextclassifier3 {
namespace mobile {

namespace {
// Returns true if and only if ptr points to a location inside allowed_range.
bool IsPointerInRange(const char *ptr, StringPiece allowed_range) {
  return (ptr >= allowed_range.data()) &&
         (ptr < (allowed_range.data() + allowed_range.size()));
}

// Returns true if and only if the memory range [start, start +
// range_size_in_bytes) is included inside allowed_range.
//
// Special case: if range_size_in_bytes == 0 (empty range) then we require that
// start is nullptr or in the allowed_range.
bool IsMemoryRangeValid(const void *start, int range_size_in_bytes,
                        StringPiece allowed_range) {
  const char *begin = reinterpret_cast<const char *>(start);
  if (range_size_in_bytes < 0) {
    return false;
  }
  if (range_size_in_bytes == 0) {
    return (start == nullptr) || IsPointerInRange(begin, allowed_range);
  }
  const char *inclusive_end = begin + (range_size_in_bytes - 1);
  return (begin <= inclusive_end) && IsPointerInRange(begin, allowed_range) &&
         IsPointerInRange(inclusive_end, allowed_range);
}

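// Returns true if and only if the quantization scales for |matrix| (one
// float16 value per matrix row) are present and fully contained inside
// |bytes|.  Logs an error and returns false otherwise.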
bool VerifyQuantizationScales(EmbeddingNetworkParams::Matrix matrix,
                              StringPiece bytes) {
  if (matrix.quant_scales == nullptr) {
    SAFTM_LOG(ERROR) << "Quantization type "
                     << static_cast<int>(matrix.quant_type)
                     << "; but no quantization scales";
    return false;
  }
  bool valid_scales = IsMemoryRangeValid(matrix.quant_scales,
                                         matrix.rows * sizeof(float16), bytes);
  if (!valid_scales) {
    SAFTM_LOG(ERROR) << "quantization scales not fully inside bytes";
    return false;
  }
  return true;
}

// Returns false if we detect a problem with |matrix|, true otherwise.  E.g.,
// we check that the array that starts at pointer matrix.elements is fully
// inside |bytes| (the range of bytes passed to the
// EmbeddingNetworkParamsFromFlatbuffer constructor).
bool VerifyMatrix(EmbeddingNetworkParams::Matrix matrix, StringPiece bytes) {
  if ((matrix.rows < 0) || (matrix.cols < 0)) {
    SAFTM_LOG(ERROR) << "Wrong matrix geometry: " << matrix.rows << " x "
                     << matrix.cols;
    return false;
  }

  const int num_elements = matrix.rows * matrix.cols;

  // Number of bytes occupied by the num_elements elements that start at
  // address matrix.elements.
  int element_range_size_in_bytes = 0;
  switch (matrix.quant_type) {
    case QuantizationType::NONE:
      element_range_size_in_bytes = num_elements * sizeof(float);
      break;
    case QuantizationType::UINT8: {
      element_range_size_in_bytes = num_elements;
      if (!VerifyQuantizationScales(matrix, bytes)) {
        return false;
      }
      break;
    }
    case QuantizationType::UINT4: {
      if (matrix.cols % 2 != 0) {
        SAFTM_LOG(ERROR) << "UINT4 doesn't work with an odd number of cols: "
                         << matrix.cols;
        return false;
      }
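      // Two UINT4 values are packed into each byte, so the elements occupy
      // num_elements / 2 bytes (cols is even, hence num_elements is even).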
      element_range_size_in_bytes = num_elements / 2;
      if (!VerifyQuantizationScales(matrix, bytes)) {
        return false;
      }
      break;
    }
    case QuantizationType::FLOAT16: {
      element_range_size_in_bytes = num_elements * sizeof(float16);

      // No need to verify the scales: FLOAT16 quantization does not use
      // scales.
      break;
    }
    default:
      SAFTM_LOG(ERROR) << "Unsupported quantization type "
                       << static_cast<int>(matrix.quant_type);
      return false;
  }
  if (matrix.elements == nullptr) {
    SAFTM_LOG(ERROR) << "matrix.elements == nullptr";
    return false;
  }
  bool valid =
      IsMemoryRangeValid(matrix.elements, element_range_size_in_bytes, bytes);
  if (!valid) {
    SAFTM_LOG(ERROR) << "elements not fully inside bytes";
    return false;
  }
  return true;
}

// Checks the geometry of the network layer represented by |weights| and
// |bias|, assuming the input to this layer has size |input_size|.  Returns
// false if we detect any problem, true otherwise.
bool GoodLayerGeometry(int input_size,
                       const EmbeddingNetworkParams::Matrix &weights,
                       const EmbeddingNetworkParams::Matrix &bias) {
  if (weights.rows != input_size) {
    SAFTM_LOG(ERROR) << "#rows " << weights.rows << " != " << input_size;
    return false;
  }
  if ((bias.rows != 1) && (bias.cols != 1)) {
    SAFTM_LOG(ERROR) << "bad bias vector geometry: " << bias.rows << " x "
                     << bias.cols;
    return false;
  }
  int bias_dimension = bias.rows * bias.cols;
  if (weights.cols != bias_dimension) {
    SAFTM_LOG(ERROR) << "#cols " << weights.cols << " != " << bias_dimension;
    return false;
  }
  return true;
}
}  // namespace

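// Example usage (a minimal sketch; assumes a std::string |flatbuffer_bytes|
// holding a serialized EmbeddingNetwork flatbuffer and the is_valid()
// accessor declared in the header):
//
//   EmbeddingNetworkParamsFromFlatbuffer params(
//       StringPiece(flatbuffer_bytes.data(), flatbuffer_bytes.size()));
//   if (!params.is_valid()) {
//     // Handle the corrupted / truncated model case.
//   }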
EmbeddingNetworkParamsFromFlatbuffer::EmbeddingNetworkParamsFromFlatbuffer(
    StringPiece bytes) {
  // We expect valid_ to be initialized to false at this point.  We set it to
  // true only if we successfully complete all initialization.  On error, we
  // return early, leaving valid_ set to false.
  SAFTM_DCHECK(!valid_);

  // NOTE: current EmbeddingNetworkParams API works only on little-endian
  // machines.  Fortunately, all modern devices are little-endian so, instead
  // of a costly API change, we support only the little-endian case.
  //
  // Technical explanation: for each Matrix, our API provides a pointer to the
  // matrix elements (see Matrix field |elements|).  For unquantized matrices,
  // that's a const float *pointer; the client code (e.g., Neurosis) uses
  // those floats directly.  That is correct if the EmbeddingNetworkParams
  // come from a proto, where the proto parsing already handled the endianness
  // differences.  But in the flatbuffer case, that's a pointer to floats in
  // little-endian format (flatbuffers always use little-endian).  If our API
  // provided access to only one element at a time, the accessor method could
  // swap the bytes "on the fly", using temporary variables.  Instead, our API
  // provides a pointer to all elements: as their number is variable (and
  // underlying data is immutable), we can't ensure the bytes of all those
  // elements are swapped without extra memory allocation to store the swapped
  // bytes (which is what using flatbuffers is supposed to prevent).
  if (!LittleEndian::IsLittleEndian()) {
    SAFTM_LOG(INFO) << "Not a little-endian machine";
    return;
  }

  const uint8_t *start = reinterpret_cast<const uint8_t *>(bytes.data());
  if (start == nullptr) {
    // Note: as |bytes| is expected to be a valid EmbeddingNetwork flatbuffer,
    // it should contain the 4-char identifier "NS00" (or a later version).
    // It can't be empty; hence StringPiece(nullptr, 0) is not legal here.
    SAFTM_LOG(ERROR) << "nullptr bytes";
    return;
  }
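  // First, check the structural integrity of the flatbuffer (offsets within
  // bounds, etc).  Deeper semantic checks (matrix geometry, element ranges)
  // are performed by ValidityChecking() below.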
  flatbuffers::Verifier verifier(start, bytes.size());
  if (!saft_fbs::VerifyEmbeddingNetworkBuffer(verifier)) {
    SAFTM_LOG(ERROR) << "Not a valid EmbeddingNetwork flatbuffer";
    return;
  }
  network_ = saft_fbs::GetEmbeddingNetwork(start);
  if (network_ == nullptr) {
    SAFTM_LOG(ERROR) << "Unable to interpret bytes as a flatbuffer";
    return;
  }

  // Perform a few extra checks before declaring this object valid.
  valid_ = ValidityChecking(bytes);
}

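// Checks that every embedding, hidden-layer, and softmax matrix lies fully
// inside |bytes| and that the geometry of consecutive layers is consistent:
// the output size of each layer matches the input size of the next one.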
bool EmbeddingNetworkParamsFromFlatbuffer::ValidityChecking(
    StringPiece bytes) const {
  int input_size = 0;
  for (int i = 0; i < embeddings_size(); ++i) {
    Matrix embeddings = GetEmbeddingMatrix(i);
    if (!VerifyMatrix(embeddings, bytes)) {
      SAFTM_LOG(ERROR) << "Bad embedding matrix #" << i;
      return false;
    }
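    // The input to the first layer is the concatenation of the embeddings of
    // all feature chunks: chunk #i contributes embedding_num_features(i)
    // embeddings of embeddings.cols dimensions each.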
    input_size += embedding_num_features(i) * embeddings.cols;
  }
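  // current_size tracks the dimension of the activation vector that flows
  // through the network: it starts as the concatenated embedding size and,
  // after each hidden layer, becomes that layer's number of columns.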
  int current_size = input_size;
  for (int i = 0; i < hidden_size(); ++i) {
    Matrix weights = GetHiddenLayerMatrix(i);
    if (!VerifyMatrix(weights, bytes)) {
      SAFTM_LOG(ERROR) << "Bad weights matrix for hidden layer #" << i;
      return false;
    }
    Matrix bias = GetHiddenLayerBias(i);
    if (!VerifyMatrix(bias, bytes)) {
      SAFTM_LOG(ERROR) << "Bad bias vector for hidden layer #" << i;
      return false;
    }
    if (!GoodLayerGeometry(current_size, weights, bias)) {
      SAFTM_LOG(ERROR) << "Bad geometry for hidden layer #" << i;
      return false;
    }
    current_size = weights.cols;
  }

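  // The softmax layer (if present) consumes the output of the last hidden
  // layer, or the concatenated embeddings directly if there are no hidden
  // layers.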
  if (HasSoftmax()) {
    Matrix weights = GetSoftmaxMatrix();
    if (!VerifyMatrix(weights, bytes)) {
      SAFTM_LOG(ERROR) << "Bad weights matrix for softmax";
      return false;
    }
    Matrix bias = GetSoftmaxBias();
    if (!VerifyMatrix(bias, bytes)) {
      SAFTM_LOG(ERROR) << "Bad bias vector for softmax";
      return false;
    }
    if (!GoodLayerGeometry(current_size, weights, bias)) {
      SAFTM_LOG(ERROR) << "Bad geometry for softmax layer";
      return false;
    }
  }
  return true;
}

// static
bool EmbeddingNetworkParamsFromFlatbuffer::InRangeIndex(int index, int limit,
                                                        const char *info) {
  if ((index >= 0) && (index < limit)) {
    return true;
  } else {
    SAFTM_LOG(ERROR) << info << " index " << index << " outside range [0, "
                     << limit << ")";
    return false;
  }
}

int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumInputChunks() const {
  const auto *input_chunks = network_->input_chunks();
  if (input_chunks == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr input_chunks";
    return 0;
  }
  return input_chunks->size();
}

const saft_fbs::InputChunk *
EmbeddingNetworkParamsFromFlatbuffer::SafeGetInputChunk(int i) const {
  if (!InRangeIndex(i, SafeGetNumInputChunks(), "input chunks")) {
    return nullptr;
  }
  const auto *input_chunks = network_->input_chunks();
  if (input_chunks == nullptr) {
    // Execution should not reach this point, due to how
    // SafeGetNumInputChunks() is implemented.  Still, just to be sure:
    SAFTM_LOG(ERROR) << "nullptr input_chunks";
    return nullptr;
  }
  const saft_fbs::InputChunk *input_chunk = input_chunks->Get(i);
  if (input_chunk == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr input chunk #" << i;
  }
  return input_chunk;
}

const saft_fbs::Matrix *
EmbeddingNetworkParamsFromFlatbuffer::SafeGetEmbeddingMatrix(int i) const {
  const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
  if (input_chunk == nullptr) return nullptr;
  const saft_fbs::Matrix *matrix = input_chunk->embedding();
  if (matrix == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr embedding matrix #" << i;
  }
  return matrix;
}

int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumLayers() const {
  const auto *layers = network_->layers();
  if (layers == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr layers";
    return 0;
  }
  return layers->size();
}

const saft_fbs::NeuralLayer *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayer(
    int i) const {
  if (!InRangeIndex(i, SafeGetNumLayers(), "layer")) {
    return nullptr;
  }
  const auto *layers = network_->layers();
  if (layers == nullptr) {
    // Execution should not reach this point, due to how SafeGetNumLayers()
    // is implemented.  Still, just to be sure:
    SAFTM_LOG(ERROR) << "nullptr layers";
    return nullptr;
  }
  const saft_fbs::NeuralLayer *layer = layers->Get(i);
  if (layer == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr layer #" << i;
  }
  return layer;
}

const saft_fbs::Matrix *
EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerWeights(int i) const {
  const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
  if (layer == nullptr) return nullptr;
  const saft_fbs::Matrix *weights = layer->weights();
  if (weights == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr weights for layer #" << i;
  }
  return weights;
}

const saft_fbs::Matrix *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerBias(
    int i) const {
  const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
  if (layer == nullptr) return nullptr;
  const saft_fbs::Matrix *bias = layer->bias();
  if (bias == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr bias for layer #" << i;
  }
  return bias;
}

// static
const float *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValues(
    const saft_fbs::Matrix *matrix) {
  if (matrix == nullptr) return nullptr;
  const flatbuffers::Vector<float> *values = matrix->values();
  if (values == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr values";
    return nullptr;
  }
  return values->data();
}

// static
const uint8_t *EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizedValues(
    const saft_fbs::Matrix *matrix) {
  if (matrix == nullptr) return nullptr;
  const flatbuffers::Vector<uint8_t> *quantized_values =
      matrix->quantized_values();
  if (quantized_values == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr quantized_values";
    return nullptr;
  }
  return quantized_values->data();
}

// static
const float16 *EmbeddingNetworkParamsFromFlatbuffer::SafeGetScales(
    const saft_fbs::Matrix *matrix) {
  if (matrix == nullptr) return nullptr;
  const flatbuffers::Vector<uint16_t> *scales = matrix->scales();
  if (scales == nullptr) {
    SAFTM_LOG(ERROR) << "nullptr scales";
    return nullptr;
  }
  return scales->data();
}

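// By convention, the softmax layer is the last layer from the flatbuffer.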
const saft_fbs::NeuralLayer *
EmbeddingNetworkParamsFromFlatbuffer::SafeGetSoftmaxLayer() const {
  int num_layers = SafeGetNumLayers();
  if (num_layers <= 0) {
    SAFTM_LOG(ERROR) << "No softmax layer";
    return nullptr;
  }
  return SafeGetLayer(num_layers - 1);
}

QuantizationType EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizationType(
    const saft_fbs::Matrix *matrix) const {
  if (matrix == nullptr) {
    return QuantizationType::NONE;
  }
  saft_fbs::QuantizationType quantization_type = matrix->quantization_type();

  // Conversion from nlp_saft::saft_fbs::QuantizationType to
  // nlp_saft::QuantizationType (due to legacy reasons, we have both).
  switch (quantization_type) {
    case saft_fbs::QuantizationType_NONE:
      return QuantizationType::NONE;
    case saft_fbs::QuantizationType_UINT8:
      return QuantizationType::UINT8;
    case saft_fbs::QuantizationType_UINT4:
      return QuantizationType::UINT4;
    case saft_fbs::QuantizationType_FLOAT16:
      return QuantizationType::FLOAT16;
    default:
      SAFTM_LOG(ERROR) << "Unsupported quantization type "
                       << static_cast<int>(quantization_type);
      return QuantizationType::NONE;
  }
}

const void *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValuesOfMatrix(
    const saft_fbs::Matrix *matrix) const {
  if (matrix == nullptr) {
    return nullptr;
  }
  saft_fbs::QuantizationType quantization_type = matrix->quantization_type();
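  // UINT8, UINT4, and FLOAT16 matrices all store their elements in the
  // quantized_values byte vector; only unquantized (NONE) matrices store
  // plain floats in values.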
  switch (quantization_type) {
    case saft_fbs::QuantizationType_NONE:
      return SafeGetValues(matrix);
    case saft_fbs::QuantizationType_UINT8:
      SAFTM_FALLTHROUGH_INTENDED;
    case saft_fbs::QuantizationType_UINT4:
      SAFTM_FALLTHROUGH_INTENDED;
    case saft_fbs::QuantizationType_FLOAT16:
      return SafeGetQuantizedValues(matrix);
    default:
      SAFTM_LOG(ERROR) << "Unsupported quantization type "
                       << static_cast<int>(quantization_type);
      return nullptr;
  }
}

}  // namespace mobile
}  // namespace libtextclassifier3