1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_ 18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_ 19 20 #include <algorithm> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 25 #include "lang_id/common/embedding-network-params.h" 26 #include "lang_id/common/flatbuffers/embedding-network_generated.h" 27 #include "lang_id/common/lite_base/float16.h" 28 #include "lang_id/common/lite_base/logging.h" 29 #include "lang_id/common/lite_strings/stringpiece.h" 30 31 namespace libtextclassifier3 { 32 namespace mobile { 33 34 // EmbeddingNetworkParams implementation backed by a flatbuffer. 35 // 36 // For info on our flatbuffer schema, see embedding-network.fbs. 37 class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams { 38 public: 39 // Constructs an EmbeddingNetworkParamsFromFlatbuffer instance, using the 40 // flatbuffer from |bytes|. 41 // 42 // IMPORTANT #1: caller should make sure |bytes| are alive during the lifetime 43 // of this EmbeddingNetworkParamsFromFlatbuffer instance. To avoid overhead, 44 // this constructor does not copy |bytes|. 45 // 46 // IMPORTANT #2: immediately after this constructor returns, we suggest you 47 // call is_valid() on the newly-constructed object and do not call any other 48 // method if the answer is negative (false). 49 explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes); 50 UpdateTaskContextParameters(mobile::TaskContext * task_context)51 bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override { 52 // This class does not provide access to the overall TaskContext. It 53 // provides only parameters for the Neurosis neural network. 54 SAFTM_LOG(DFATAL) << "Not supported"; 55 return false; 56 } 57 is_valid()58 bool is_valid() const override { return valid_; } 59 embeddings_size()60 int embeddings_size() const override { return SafeGetNumInputChunks(); } 61 embeddings_num_rows(int i)62 int embeddings_num_rows(int i) const override { 63 const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); 64 return SafeGetNumRows(matrix); 65 } 66 embeddings_num_cols(int i)67 int embeddings_num_cols(int i) const override { 68 const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); 69 return SafeGetNumCols(matrix); 70 } 71 embeddings_weights(int i)72 const void *embeddings_weights(int i) const override { 73 const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); 74 return SafeGetValuesOfMatrix(matrix); 75 } 76 embeddings_quant_type(int i)77 QuantizationType embeddings_quant_type(int i) const override { 78 const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); 79 return SafeGetQuantizationType(matrix); 80 } 81 embeddings_quant_scales(int i)82 const float16 *embeddings_quant_scales(int i) const override { 83 const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); 84 return SafeGetScales(matrix); 85 } 86 hidden_size()87 int hidden_size() const override { 88 // -1 because last layer is always the softmax layer. 89 return std::max(SafeGetNumLayers() - 1, 0); 90 } 91 hidden_num_rows(int i)92 int hidden_num_rows(int i) const override { 93 const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); 94 return SafeGetNumRows(weights); 95 } 96 hidden_num_cols(int i)97 int hidden_num_cols(int i) const override { 98 const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); 99 return SafeGetNumCols(weights); 100 } 101 hidden_weights_quant_type(int i)102 QuantizationType hidden_weights_quant_type(int i) const override { 103 const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); 104 return SafeGetQuantizationType(weights); 105 } 106 hidden_weights(int i)107 const void *hidden_weights(int i) const override { 108 const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); 109 return SafeGetValuesOfMatrix(weights); 110 } 111 hidden_bias_size()112 int hidden_bias_size() const override { return hidden_size(); } 113 hidden_bias_num_rows(int i)114 int hidden_bias_num_rows(int i) const override { 115 const saft_fbs::Matrix *bias = SafeGetLayerBias(i); 116 return SafeGetNumRows(bias); 117 } 118 hidden_bias_num_cols(int i)119 int hidden_bias_num_cols(int i) const override { 120 const saft_fbs::Matrix *bias = SafeGetLayerBias(i); 121 return SafeGetNumCols(bias); 122 } 123 hidden_bias_weights(int i)124 const void *hidden_bias_weights(int i) const override { 125 const saft_fbs::Matrix *bias = SafeGetLayerBias(i); 126 return SafeGetValues(bias); 127 } 128 softmax_size()129 int softmax_size() const override { return (SafeGetNumLayers() > 0) ? 1 : 0; } 130 softmax_num_rows(int i)131 int softmax_num_rows(int i) const override { 132 const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); 133 return SafeGetNumRows(weights); 134 } 135 softmax_num_cols(int i)136 int softmax_num_cols(int i) const override { 137 const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); 138 return SafeGetNumCols(weights); 139 } 140 softmax_weights_quant_type(int i)141 QuantizationType softmax_weights_quant_type(int i) const override { 142 const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); 143 return SafeGetQuantizationType(weights); 144 } 145 softmax_weights(int i)146 const void *softmax_weights(int i) const override { 147 const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); 148 return SafeGetValuesOfMatrix(weights); 149 } 150 softmax_bias_size()151 int softmax_bias_size() const override { return softmax_size(); } 152 softmax_bias_num_rows(int i)153 int softmax_bias_num_rows(int i) const override { 154 const saft_fbs::Matrix *bias = SafeGetSoftmaxBias(); 155 return SafeGetNumRows(bias); 156 } 157 softmax_bias_num_cols(int i)158 int softmax_bias_num_cols(int i) const override { 159 const saft_fbs::Matrix *bias = SafeGetSoftmaxBias(); 160 return SafeGetNumCols(bias); 161 } 162 softmax_bias_weights(int i)163 const void *softmax_bias_weights(int i) const override { 164 const saft_fbs::Matrix *bias = SafeGetSoftmaxBias(); 165 return SafeGetValues(bias); 166 } 167 embedding_num_features_size()168 int embedding_num_features_size() const override { 169 return SafeGetNumInputChunks(); 170 } 171 embedding_num_features(int i)172 int embedding_num_features(int i) const override { 173 if (!InRangeIndex(i, embedding_num_features_size(), 174 "embedding num features")) { 175 return 0; 176 } 177 const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i); 178 if (input_chunk == nullptr) { 179 return 0; 180 } 181 return input_chunk->num_features(); 182 } 183 has_is_precomputed()184 bool has_is_precomputed() const override { return false; } is_precomputed()185 bool is_precomputed() const override { return false; } 186 187 private: 188 // Returns true if and only if index is in [0, limit). info should be a 189 // pointer to a zero-terminated array of chars (ideally a literal string, 190 // e.g. "layer") indicating what the index refers to; info is used to make log 191 // messages more informative. 192 static bool InRangeIndex(int index, int limit, const char *info); 193 194 // Returns network_->input_chunks()->size(), if all dereferences are safe 195 // (i.e., no nullptr); otherwise, returns 0. 196 int SafeGetNumInputChunks() const; 197 198 // Returns network_->input_chunks()->Get(i), if all dereferences are safe 199 // (i.e., no nullptr) otherwise, returns nullptr. 200 const saft_fbs::InputChunk *SafeGetInputChunk(int i) const; 201 202 // Returns network_->input_chunks()->Get(i)->embedding(), if all dereferences 203 // are safe (i.e., no nullptr); otherwise, returns nullptr. 204 const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const; 205 206 // Returns network_->layers()->size(), if all dereferences are safe (i.e., no 207 // nullptr); otherwise, returns 0. 208 int SafeGetNumLayers() const; 209 210 // Returns network_->layers()->Get(i), if all dereferences are safe 211 // (i.e., no nullptr); otherwise, returns nullptr. 212 const saft_fbs::NeuralLayer *SafeGetLayer(int i) const; 213 214 // Returns network_->layers()->Get(i)->weights(), if all dereferences are safe 215 // (i.e., no nullptr); otherwise, returns nullptr. 216 const saft_fbs::Matrix *SafeGetLayerWeights(int i) const; 217 218 // Returns network_->layers()->Get(i)->bias(), if all dereferences are safe 219 // (i.e., no nullptr); otherwise, returns nullptr. 220 const saft_fbs::Matrix *SafeGetLayerBias(int i) const; 221 SafeGetNumRows(const saft_fbs::Matrix * matrix)222 static int SafeGetNumRows(const saft_fbs::Matrix *matrix) { 223 return (matrix == nullptr) ? 0 : matrix->rows(); 224 } 225 SafeGetNumCols(const saft_fbs::Matrix * matrix)226 static int SafeGetNumCols(const saft_fbs::Matrix *matrix) { 227 return (matrix == nullptr) ? 0 : matrix->cols(); 228 } 229 230 // Returns matrix->values()->data() if all dereferences are safe (i.e., no 231 // nullptr); otherwise, returns nullptr. 232 static const float *SafeGetValues(const saft_fbs::Matrix *matrix); 233 234 // Returns matrix->quantized_values()->data() if all dereferences are safe 235 // (i.e., no nullptr); otherwise, returns nullptr. 236 static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix); 237 238 // Returns matrix->scales()->data() if all dereferences are safe (i.e., no 239 // nullptr); otherwise, returns nullptr. 240 static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix); 241 242 // Returns network_->layers()->Get(last_index) with last_index = 243 // SafeGetNumLayers() - 1, if all dereferences are safe (i.e., no nullptr) and 244 // there exists at least one layer; otherwise, returns nullptr. 245 const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const; 246 SafeGetSoftmaxWeights()247 const saft_fbs::Matrix *SafeGetSoftmaxWeights() const { 248 const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer(); 249 return (layer == nullptr) ? nullptr : layer->weights(); 250 } 251 SafeGetSoftmaxBias()252 const saft_fbs::Matrix *SafeGetSoftmaxBias() const { 253 const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer(); 254 return (layer == nullptr) ? nullptr : layer->bias(); 255 } 256 257 // Returns the quantization type for |matrix|. Returns NONE in case of 258 // problems (e.g., matrix is nullptr or unknown quantization type). 259 QuantizationType SafeGetQuantizationType( 260 const saft_fbs::Matrix *matrix) const; 261 262 // Returns a pointer to the values (float, uint8, or float16, depending on 263 // quantization) from |matrix|, in row-major order. Returns nullptr in case 264 // of a problem. 265 const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const; 266 267 // Performs some validity checks. E.g., check that dimensions of the network 268 // layers match. Also checks that all pointers we return are inside the 269 // |bytes| passed to the constructor, such that client that reads from those 270 // pointers will not run into troubles. 271 bool ValidityChecking(StringPiece bytes) const; 272 273 // True if these params are valid. May be false if the original proto was 274 // corrupted. We prefer to set this to false to CHECK-failing. 275 bool valid_ = false; 276 277 // EmbeddingNetwork flatbuffer from the bytes passed as parameter to the 278 // constructor; see constructor doc. 279 const saft_fbs::EmbeddingNetwork *network_ = nullptr; 280 }; 281 282 } // namespace mobile 283 } // namespace nlp_saft 284 285 #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_ 286