• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"

#include <cstdint>
#include <limits>

#include "lang_id/common/lite_base/endian.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"
22 
23 namespace libtextclassifier3 {
24 namespace mobile {
25 
26 namespace {
27 // Returns true if and only if ptr points to a location inside allowed_range.
IsPointerInRange(const char * ptr,StringPiece allowed_range)28 bool IsPointerInRange(const char *ptr, StringPiece allowed_range) {
29   return (ptr >= allowed_range.data()) &&
30          (ptr < (allowed_range.data() + allowed_range.size()));
31 }
32 
33 // Returns true if and only if the memory range [start, start +
34 // range_size_in_bytes) is included inside allowed_range.
35 //
36 // Special case: if range_size_in_bytes == 0 (empty range) then we require that
37 // start is nullptr or in the allowed_range.
IsMemoryRangeValid(const void * start,int range_size_in_bytes,StringPiece allowed_range)38 bool IsMemoryRangeValid(const void *start, int range_size_in_bytes,
39                         StringPiece allowed_range) {
40   const char *begin = reinterpret_cast<const char *>(start);
41   if (range_size_in_bytes < 0) {
42     return false;
43   }
44   if (range_size_in_bytes == 0) {
45     return (start == nullptr) || IsPointerInRange(begin, allowed_range);
46   }
47   const char *inclusive_end = begin + (range_size_in_bytes - 1);
48   return (begin <= inclusive_end) && IsPointerInRange(begin, allowed_range) &&
49          IsPointerInRange(inclusive_end, allowed_range);
50 }
51 
VerifyQuantizationScales(EmbeddingNetworkParams::Matrix matrix,StringPiece bytes)52 bool VerifyQuantizationScales(EmbeddingNetworkParams::Matrix matrix,
53                               StringPiece bytes) {
54   if (matrix.quant_scales == nullptr) {
55     SAFTM_LOG(ERROR) << "Quantization type "
56                      << static_cast<int>(matrix.quant_type)
57                      << "; but no quantization scales";
58     return false;
59   }
60   bool valid_scales = IsMemoryRangeValid(matrix.quant_scales,
61                                          matrix.rows * sizeof(float16), bytes);
62   if (!valid_scales) {
63     SAFTM_LOG(ERROR) << "quantization scales not fully inside bytes";
64     return false;
65   }
66   return true;
67 }
68 
69 // Returns false if we detect a problem with |matrix|, true otherwise.  E.g., we
70 // check that the array that starts at pointer matrix.elements is fully inside
71 // |bytes| (the range of bytes passed to the
72 // EmbeddingNetworkParamsFromFlatbuffer constructor).
VerifyMatrix(EmbeddingNetworkParams::Matrix matrix,StringPiece bytes)73 bool VerifyMatrix(EmbeddingNetworkParams::Matrix matrix, StringPiece bytes) {
74   if ((matrix.rows < 0) || (matrix.cols < 0)) {
75     SAFTM_LOG(ERROR) << "Wrong matrix geometry: " << matrix.rows << " x "
76                      << matrix.cols;
77     return false;
78   }
79 
80   const int num_elements = matrix.rows * matrix.cols;
81 
82   // Number of bytes occupied by the num_elements elements that start at address
83   // matrix.elements.
84   int element_range_size_in_bytes = 0;
85   switch (matrix.quant_type) {
86     case QuantizationType::NONE:
87       element_range_size_in_bytes = num_elements * sizeof(float);
88       break;
89     case QuantizationType::UINT8: {
90       element_range_size_in_bytes = num_elements;
91       if (!VerifyQuantizationScales(matrix, bytes)) {
92         return false;
93       }
94       break;
95     }
96     case QuantizationType::UINT4: {
97       if (matrix.cols % 2 != 0) {
98         SAFTM_LOG(ERROR) << "UINT4 doesn't work with odd #cols" << matrix.cols;
99         return false;
100       }
101       element_range_size_in_bytes = num_elements / 2;
102       if (!VerifyQuantizationScales(matrix, bytes)) {
103         return false;
104       }
105       break;
106     }
107     case QuantizationType::FLOAT16: {
108       element_range_size_in_bytes = num_elements * sizeof(float16);
109 
110       // No need to verify the scales: FLOAT16 quantization does not use scales.
111       break;
112     }
113     default:
114       SAFTM_LOG(ERROR) << "Unsupported quantization type "
115                        << static_cast<int>(matrix.quant_type);
116       return false;
117   }
118   if (matrix.elements == nullptr) {
119     SAFTM_LOG(ERROR) << "matrix.elements == nullptr";
120     return false;
121   }
122   bool valid =
123       IsMemoryRangeValid(matrix.elements, element_range_size_in_bytes, bytes);
124   if (!valid) {
125     SAFTM_LOG(ERROR) << "elements not fully inside bytes";
126     return false;
127   }
128   return true;
129 }
130 
131 // Checks the geometry of the network layer represented by |weights| and |bias|,
132 // assuming the input to this layer has size |input_size|.  Returns false if we
133 // detect any problem, true otherwise.
GoodLayerGeometry(int input_size,const EmbeddingNetworkParams::Matrix & weights,const EmbeddingNetworkParams::Matrix & bias)134 bool GoodLayerGeometry(int input_size,
135                        const EmbeddingNetworkParams::Matrix &weights,
136                        const EmbeddingNetworkParams::Matrix &bias) {
137   if (weights.rows != input_size) {
138     SAFTM_LOG(ERROR) << "#rows " << weights.rows << " != " << input_size;
139     return false;
140   }
141   if ((bias.rows != 1) && (bias.cols != 1)) {
142     SAFTM_LOG(ERROR) << "bad bias vector geometry: " << bias.rows << " x "
143                      << bias.cols;
144     return false;
145   }
146   int bias_dimension = bias.rows * bias.cols;
147   if (weights.cols != bias_dimension) {
148     SAFTM_LOG(ERROR) << "#cols " << weights.cols << " != " << bias_dimension;
149     return false;
150   }
151   return true;
152 }
153 }  // namespace
154 
EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes)155 EmbeddingNetworkParamsFromFlatbuffer::EmbeddingNetworkParamsFromFlatbuffer(
156     StringPiece bytes) {
157   // We expect valid_ to be initialized to false at this point.  We set it to
158   // true only if we successfully complete all initialization.  On error, we
159   // return early, leaving valid_ set to false.
160   SAFTM_DCHECK(!valid_);
161 
162   // NOTE: current EmbeddingNetworkParams API works only on little-endian
163   // machines.  Fortunately, all modern devices are little-endian so, instead of
164   // a costly API change, we support only the little-endian case.
165   //
166   // Technical explanation: for each Matrix, our API provides a pointer to the
167   // matrix elements (see Matrix field |elements|).  For unquantized matrices,
168   // that's a const float *pointer; the client code (e.g., Neurosis) uses those
169   // floats directly.  That is correct if the EmbeddingNetworkParams come from a
170   // proto, where the proto parsing already handled the endianness differences.
171   // But in the flatbuffer case, that's a pointer to floats in little-endian
172   // format (flatbuffers always use little-endian).  If our API provided access
173   // to only one element at a time, the accessor method could swap the bytes "on
174   // the fly", using temporary variables.  Instead, our API provides a pointer
175   // to all elements: as their number is variable (and underlying data is
176   // immutable), we can't ensure the bytes of all those elements are swapped
177   // without extra memory allocation to store the swapped bytes (which is what
178   // using flatbuffers is supposed to prevent).
179   if (!LittleEndian::IsLittleEndian()) {
180     SAFTM_LOG(INFO) << "Not a little-endian machine";
181     return;
182   }
183 
184   const uint8_t *start = reinterpret_cast<const uint8_t *>(bytes.data());
185   if (start == nullptr) {
186     // Note: as |bytes| is expected to be a valid EmbeddingNetwork flatbuffer,
187     // it should contain the 4-char identifier "NS00" (or a later version).  It
188     // can't be empty; hence StringPiece(nullptr, 0) is not legal here.
189     SAFTM_LOG(ERROR) << "nullptr bytes";
190     return;
191   }
192   flatbuffers::Verifier verifier(start, bytes.size());
193   if (!saft_fbs::VerifyEmbeddingNetworkBuffer(verifier)) {
194     SAFTM_LOG(ERROR) << "Not a valid EmbeddingNetwork flatbuffer";
195     return;
196   }
197   network_ = saft_fbs::GetEmbeddingNetwork(start);
198   if (network_ == nullptr) {
199     SAFTM_LOG(ERROR) << "Unable to interpret bytes as a flatbuffer";
200     return;
201   }
202 
203   // Perform a few extra checks before declaring this object valid.
204   valid_ = ValidityChecking(bytes);
205 }
206 
ValidityChecking(StringPiece bytes) const207 bool EmbeddingNetworkParamsFromFlatbuffer::ValidityChecking(
208     StringPiece bytes) const {
209   int input_size = 0;
210   for (int i = 0; i < embeddings_size(); ++i) {
211     Matrix embeddings = GetEmbeddingMatrix(i);
212     if (!VerifyMatrix(embeddings, bytes)) {
213       SAFTM_LOG(ERROR) << "Bad embedding matrix #" << i;
214       return false;
215     }
216     input_size += embedding_num_features(i) * embeddings.cols;
217   }
218   int current_size = input_size;
219   for (int i = 0; i < hidden_size(); ++i) {
220     Matrix weights = GetHiddenLayerMatrix(i);
221     if (!VerifyMatrix(weights, bytes)) {
222       SAFTM_LOG(ERROR) << "Bad weights matrix for hidden layer #" << i;
223       return false;
224     }
225     Matrix bias = GetHiddenLayerBias(i);
226     if (!VerifyMatrix(bias, bytes)) {
227       SAFTM_LOG(ERROR) << "Bad bias vector for hidden layer #" << i;
228       return false;
229     }
230     if (!GoodLayerGeometry(current_size, weights, bias)) {
231       SAFTM_LOG(ERROR) << "Bad geometry for hidden layer #" << i;
232       return false;
233     }
234     current_size = weights.cols;
235   }
236 
237   if (HasSoftmax()) {
238     Matrix weights = GetSoftmaxMatrix();
239     if (!VerifyMatrix(weights, bytes)) {
240       SAFTM_LOG(ERROR) << "Bad weights matrix for softmax";
241       return false;
242     }
243     Matrix bias = GetSoftmaxBias();
244     if (!VerifyMatrix(bias, bytes)) {
245       SAFTM_LOG(ERROR) << "Bad bias vector for softmax";
246       return false;
247     }
248     if (!GoodLayerGeometry(current_size, weights, bias)) {
249       SAFTM_LOG(ERROR) << "Bad geometry for softmax layer";
250       return false;
251     }
252   }
253   return true;
254 }
255 
256 // static
InRangeIndex(int index,int limit,const char * info)257 bool EmbeddingNetworkParamsFromFlatbuffer::InRangeIndex(int index, int limit,
258                                                         const char *info) {
259   if ((index >= 0) && (index < limit)) {
260     return true;
261   } else {
262     SAFTM_LOG(ERROR) << info << " index " << index << " outside range [0, "
263                      << limit << ")";
264     return false;
265   }
266 }
267 
SafeGetNumInputChunks() const268 int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumInputChunks() const {
269   const auto *input_chunks = network_->input_chunks();
270   if (input_chunks == nullptr) {
271     SAFTM_LOG(ERROR) << "nullptr input_chunks";
272     return 0;
273   }
274   return input_chunks->size();
275 }
276 
277 const saft_fbs::InputChunk *
SafeGetInputChunk(int i) const278 EmbeddingNetworkParamsFromFlatbuffer::SafeGetInputChunk(int i) const {
279   if (!InRangeIndex(i, SafeGetNumInputChunks(), "input chunks")) {
280     return nullptr;
281   }
282   const auto *input_chunks = network_->input_chunks();
283   if (input_chunks == nullptr) {
284     // Execution should not reach this point, due to how SafeGetNumInputChunks()
285     // is implemented.  Still, just to be sure:
286     SAFTM_LOG(ERROR) << "nullptr input_chunks";
287     return nullptr;
288   }
289   const saft_fbs::InputChunk *input_chunk = input_chunks->Get(i);
290   if (input_chunk == nullptr) {
291     SAFTM_LOG(ERROR) << "nullptr input chunk #" << i;
292   }
293   return input_chunk;
294 }
295 
296 const saft_fbs::Matrix *
SafeGetEmbeddingMatrix(int i) const297 EmbeddingNetworkParamsFromFlatbuffer::SafeGetEmbeddingMatrix(int i) const {
298   const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
299   if (input_chunk == nullptr) return nullptr;
300   const saft_fbs::Matrix *matrix = input_chunk->embedding();
301   if (matrix == nullptr) {
302     SAFTM_LOG(ERROR) << "nullptr embeding matrix #" << i;
303   }
304   return matrix;
305 }
306 
SafeGetNumLayers() const307 int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumLayers() const {
308   const auto *layers = network_->layers();
309   if (layers == nullptr) {
310     SAFTM_LOG(ERROR) << "nullptr layers";
311     return 0;
312   }
313   return layers->size();
314 }
315 
SafeGetLayer(int i) const316 const saft_fbs::NeuralLayer *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayer(
317     int i) const {
318   if (!InRangeIndex(i, SafeGetNumLayers(), "layer")) {
319     return nullptr;
320   }
321   const auto *layers = network_->layers();
322   if (layers == nullptr) {
323     // Execution should not reach this point, due to how SafeGetNumLayers()
324     // is implemented.  Still, just to be sure:
325     SAFTM_LOG(ERROR) << "nullptr layers";
326     return nullptr;
327   }
328   const saft_fbs::NeuralLayer *layer = layers->Get(i);
329   if (layer == nullptr) {
330     SAFTM_LOG(ERROR) << "nullptr layer #" << i;
331   }
332   return layer;
333 }
334 
335 const saft_fbs::Matrix *
SafeGetLayerWeights(int i) const336 EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerWeights(int i) const {
337   const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
338   if (layer == nullptr) return nullptr;
339   const saft_fbs::Matrix *weights = layer->weights();
340   if (weights == nullptr) {
341     SAFTM_LOG(ERROR) << "nullptr weights for layer #" << i;
342   }
343   return weights;
344 }
345 
SafeGetLayerBias(int i) const346 const saft_fbs::Matrix *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerBias(
347     int i) const {
348   const saft_fbs::NeuralLayer *layer = SafeGetLayer(i);
349   if (layer == nullptr) return nullptr;
350   const saft_fbs::Matrix *bias = layer->bias();
351   if (bias == nullptr) {
352     SAFTM_LOG(ERROR) << "nullptr bias for layer #" << i;
353   }
354   return bias;
355 }
356 
357 // static
SafeGetValues(const saft_fbs::Matrix * matrix)358 const float *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValues(
359     const saft_fbs::Matrix *matrix) {
360   if (matrix == nullptr) return nullptr;
361   const flatbuffers::Vector<float> *values = matrix->values();
362   if (values == nullptr) {
363     SAFTM_LOG(ERROR) << "nullptr values";
364   }
365   return values->data();
366 }
367 
368 // static
SafeGetQuantizedValues(const saft_fbs::Matrix * matrix)369 const uint8_t *EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizedValues(
370     const saft_fbs::Matrix *matrix) {
371   if (matrix == nullptr) return nullptr;
372   const flatbuffers::Vector<uint8_t> *quantized_values =
373       matrix->quantized_values();
374   if (quantized_values == nullptr) {
375     SAFTM_LOG(ERROR) << "nullptr quantized_values";
376   }
377   return quantized_values->data();
378 }
379 
380 // static
SafeGetScales(const saft_fbs::Matrix * matrix)381 const float16 *EmbeddingNetworkParamsFromFlatbuffer::SafeGetScales(
382     const saft_fbs::Matrix *matrix) {
383   if (matrix == nullptr) return nullptr;
384   const flatbuffers::Vector<uint16_t> *scales = matrix->scales();
385   if (scales == nullptr) {
386     SAFTM_LOG(ERROR) << "nullptr scales";
387     return nullptr;
388   }
389   return scales->data();
390 }
391 
392 const saft_fbs::NeuralLayer *
SafeGetSoftmaxLayer() const393 EmbeddingNetworkParamsFromFlatbuffer::SafeGetSoftmaxLayer() const {
394   int num_layers = SafeGetNumLayers();
395   if (num_layers <= 0) {
396     SAFTM_LOG(ERROR) << "No softmax layer";
397     return nullptr;
398   }
399   return SafeGetLayer(num_layers - 1);
400 }
401 
SafeGetQuantizationType(const saft_fbs::Matrix * matrix) const402 QuantizationType EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizationType(
403     const saft_fbs::Matrix *matrix) const {
404   if (matrix == nullptr) {
405     return QuantizationType::NONE;
406   }
407   saft_fbs::QuantizationType quantization_type = matrix->quantization_type();
408 
409   // Conversion from nlp_saft::saft_fbs::QuantizationType to
410   // nlp_saft::QuantizationType (due to legacy reasons, we have both).
411   switch (quantization_type) {
412     case saft_fbs::QuantizationType_NONE:
413       return QuantizationType::NONE;
414     case saft_fbs::QuantizationType_UINT8:
415       return QuantizationType::UINT8;
416     case saft_fbs::QuantizationType_UINT4:
417       return QuantizationType::UINT4;
418     case saft_fbs::QuantizationType_FLOAT16:
419       return QuantizationType::FLOAT16;
420     default:
421       SAFTM_LOG(ERROR) << "Unsupported quantization type "
422                        << static_cast<int>(quantization_type);
423       return QuantizationType::NONE;
424   }
425 }
426 
SafeGetValuesOfMatrix(const saft_fbs::Matrix * matrix) const427 const void *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValuesOfMatrix(
428     const saft_fbs::Matrix *matrix) const {
429   if (matrix == nullptr) {
430     return nullptr;
431   }
432   saft_fbs::QuantizationType quantization_type = matrix->quantization_type();
433   switch (quantization_type) {
434     case saft_fbs::QuantizationType_NONE:
435       return SafeGetValues(matrix);
436     case saft_fbs::QuantizationType_UINT8:
437       SAFTM_FALLTHROUGH_INTENDED;
438     case saft_fbs::QuantizationType_UINT4:
439       SAFTM_FALLTHROUGH_INTENDED;
440     case saft_fbs::QuantizationType_FLOAT16:
441       return SafeGetQuantizedValues(matrix);
442     default:
443       SAFTM_LOG(ERROR) << "Unsupported quantization type "
444                        << static_cast<int>(quantization_type);
445       return nullptr;
446   }
447 }
448 
449 }  // namespace mobile
450 }  // namespace nlp_saft
451