1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_ 18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_ 19 20 #include <string> 21 22 #include "lang_id/common/fel/task-context.h" 23 #include "lang_id/common/lite_base/float16.h" 24 #include "lang_id/common/lite_base/logging.h" 25 26 namespace libtextclassifier3 { 27 28 enum class QuantizationType { 29 NONE = 0, 30 31 // Quantization to 8 bit unsigned ints. 32 UINT8, 33 34 // Quantization to 4 bit unsigned ints. 35 UINT4, 36 37 // Quantization to 16 bit floats, the type defined in 38 // lang_id/common/float16.h 39 FLOAT16, 40 41 // NOTE: for backward compatibility, if you add a new value to this enum, add 42 // it *at the end*, such that you do not change the integer values of the 43 // existing enum values. 44 }; 45 46 // Converts "UINT8" -> QuantizationType::UINT8, and so on. 47 QuantizationType ParseQuantizationType(const std::string &s); 48 49 // API for accessing parameters for a feed-forward neural network with 50 // embeddings. 51 // 52 // 53 // In fact, we provide two APIs: a high-level (and highly-recommented) API, with 54 // methods named using the BigCamel notation (e.g., GetEmbeddingMatrix()) and a 55 // low-level API, using C-style names (e.g., softmax_num_cols()). 56 // 57 // Note: the API below is meant to allow the inference code (the class 58 // libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with no need 59 // for transposing any matrix (which would require extra overhead on mobile 60 // devices). Hence, as indicated by the comments for the API methods, some of 61 // the matrices below are the transposes of the corresponding matrices from the 62 // original proto. 63 class EmbeddingNetworkParams { 64 public: ~EmbeddingNetworkParams()65 virtual ~EmbeddingNetworkParams() {} 66 67 // Returns true if these params are valid. False otherwise (e.g., if the 68 // underlying data is corrupted). If is_valid() returns false, clients should 69 // not call any other method on that instance of EmbeddingNetworkParams. If 70 // is_valid() returns true, then calls to the API methods below should not 71 // crash *if they are called with index parameters in bounds*. E.g., if 72 // is_valid() and 0 <= i < embeddings_size(), then GetEmbeddingMatrix(i) 73 // should not crash. 74 virtual bool is_valid() const = 0; 75 76 // **** High-level API. 77 78 // Simple representation of a matrix. This small struct that doesn't own any 79 // resource intentionally supports copy / assign, to simplify our APIs. 80 struct Matrix { 81 // Number of rows. 82 int rows = 0; 83 84 // Number of columns. 85 int cols = 0; 86 87 QuantizationType quant_type = QuantizationType::NONE; 88 89 // Pointer to matrix elements, in row-major order 90 // (https://en.wikipedia.org/wiki/Row-major_order) Not owned. 91 const void *elements = nullptr; 92 93 // Quantization scales: one scale for each row. 94 const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr; 95 }; 96 97 // Returns i-th embedding matrix. Crashes on out of bounds indices. 98 // 99 // This is the transpose of the corresponding matrix from the original proto. GetEmbeddingMatrix(int i)100 Matrix GetEmbeddingMatrix(int i) const { 101 CheckIndex(i, embeddings_size(), "embedding matrix"); 102 Matrix matrix; 103 matrix.rows = embeddings_num_rows(i); 104 matrix.cols = embeddings_num_cols(i); 105 matrix.elements = embeddings_weights(i); 106 matrix.quant_type = embeddings_quant_type(i); 107 matrix.quant_scales = embeddings_quant_scales(i); 108 return matrix; 109 } 110 111 // Returns weight matrix for i-th hidden layer. Crashes on out of bounds 112 // indices. 113 // 114 // This is the transpose of the corresponding matrix from the original proto. GetHiddenLayerMatrix(int i)115 Matrix GetHiddenLayerMatrix(int i) const { 116 CheckIndex(i, hidden_size(), "hidden layer"); 117 Matrix matrix; 118 matrix.rows = hidden_num_rows(i); 119 matrix.cols = hidden_num_cols(i); 120 121 // Quantization not supported here. 122 matrix.quant_type = hidden_weights_quant_type(i); 123 matrix.elements = hidden_weights(i); 124 return matrix; 125 } 126 127 // Returns bias for i-th hidden layer. Technically a Matrix, but we expect it 128 // to be a row/column vector (i.e., num rows or num cols is 1). However, we 129 // don't CHECK for that: we just provide access to underlying data. Crashes 130 // on out of bounds indices. GetHiddenLayerBias(int i)131 Matrix GetHiddenLayerBias(int i) const { 132 CheckIndex(i, hidden_bias_size(), "hidden layer bias"); 133 Matrix matrix; 134 matrix.rows = hidden_bias_num_rows(i); 135 matrix.cols = hidden_bias_num_cols(i); 136 137 // Quantization not supported here. 138 matrix.quant_type = QuantizationType::NONE; 139 matrix.elements = hidden_bias_weights(i); 140 return matrix; 141 } 142 143 // Returns true if a softmax layer exists. HasSoftmax()144 bool HasSoftmax() const { 145 return softmax_size() == 1; 146 } 147 148 // Returns weight matrix for the softmax layer. Note: should be called only 149 // if HasSoftmax() is true. 150 // 151 // This is the transpose of the corresponding matrix from the original proto. GetSoftmaxMatrix()152 Matrix GetSoftmaxMatrix() const { 153 SAFTM_CHECK(HasSoftmax()) << "No softmax layer."; 154 Matrix matrix; 155 matrix.rows = softmax_num_rows(0); 156 matrix.cols = softmax_num_cols(0); 157 158 // Quantization not supported here. 159 matrix.quant_type = softmax_weights_quant_type(0); 160 matrix.elements = softmax_weights(0); 161 return matrix; 162 } 163 164 // Returns bias for the softmax layer. Technically a Matrix, but we expect it 165 // to be a row/column vector (i.e., num rows or num cols is 1). However, we 166 // don't CHECK for that: we just provide access to underlying data. GetSoftmaxBias()167 Matrix GetSoftmaxBias() const { 168 SAFTM_CHECK(HasSoftmax()) << "No softmax layer."; 169 Matrix matrix; 170 matrix.rows = softmax_bias_num_rows(0); 171 matrix.cols = softmax_bias_num_cols(0); 172 173 // Quantization not supported here. 174 matrix.quant_type = QuantizationType::NONE; 175 matrix.elements = softmax_bias_weights(0); 176 return matrix; 177 } 178 179 // Updates the EmbeddingNetwork-related parameters from task_context. Returns 180 // true on success, false on error. 181 virtual bool UpdateTaskContextParameters( 182 mobile::TaskContext *task_context) = 0; 183 184 // **** Low-level API. 185 // 186 // * Most low-level API methods are documented by giving an equivalent 187 // function call on proto, the original proto (of type 188 // EmbeddingNetworkProto) which was used to generate the C++ code. 189 // 190 // * To simplify our generation code, optional proto fields of message type 191 // are treated as repeated fields with 0 or 1 instances. As such, we have 192 // *_size() methods for such optional fields: they return 0 or 1. 193 // 194 // * "transpose(M)" denotes the transpose of a matrix M. 195 196 // ** Access methods for repeated MatrixParams embeddings. 197 // 198 // Returns proto.embeddings_size(). 199 virtual int embeddings_size() const = 0; 200 201 // Returns number of rows of transpose(proto.embeddings(i)). 202 virtual int embeddings_num_rows(int i) const = 0; 203 204 // Returns number of columns of transpose(proto.embeddings(i)). 205 virtual int embeddings_num_cols(int i) const = 0; 206 207 // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major 208 // order. NOTE: for unquantized embeddings, this returns a pointer to float; 209 // for quantized embeddings, this returns a pointer to uint8. 210 virtual const void *embeddings_weights(int i) const = 0; 211 embeddings_quant_type(int i)212 virtual QuantizationType embeddings_quant_type(int i) const { 213 return QuantizationType::NONE; 214 } 215 embeddings_quant_scales(int i)216 virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales( 217 int i) const { 218 return nullptr; 219 } 220 221 // ** Access methods for repeated MatrixParams hidden. 222 // 223 // Returns embedding_network_proto.hidden_size(). 224 virtual int hidden_size() const = 0; 225 226 // Returns embedding_network_proto.hidden(i).rows(). 227 virtual int hidden_num_rows(int i) const = 0; 228 229 // Returns embedding_network_proto.hidden(i).rows(). 230 virtual int hidden_num_cols(int i) const = 0; 231 232 // Returns quantization mode for the weights of the i-th hidden layer. hidden_weights_quant_type(int i)233 virtual QuantizationType hidden_weights_quant_type(int i) const { 234 return QuantizationType::NONE; 235 } 236 237 // Returns pointer to beginning of array of floats with all values from 238 // embedding_network_proto.hidden(i). 239 virtual const void *hidden_weights(int i) const = 0; 240 241 // ** Access methods for repeated MatrixParams hidden_bias. 242 // 243 // Returns proto.hidden_bias_size(). 244 virtual int hidden_bias_size() const = 0; 245 246 // Returns number of rows of proto.hidden_bias(i). 247 virtual int hidden_bias_num_rows(int i) const = 0; 248 249 // Returns number of columns of proto.hidden_bias(i). 250 virtual int hidden_bias_num_cols(int i) const = 0; 251 252 // Returns pointer to elements of proto.hidden_bias(i), in row-major order. 253 virtual const void *hidden_bias_weights(int i) const = 0; 254 255 // ** Access methods for optional MatrixParams softmax. 256 // 257 // Returns 1 if proto has optional field softmax, 0 otherwise. 258 virtual int softmax_size() const = 0; 259 260 // Returns number of rows of transpose(proto.softmax()). 261 virtual int softmax_num_rows(int i) const = 0; 262 263 // Returns number of columns of transpose(proto.softmax()). 264 virtual int softmax_num_cols(int i) const = 0; 265 266 // Returns quantization mode for the softmax weights. softmax_weights_quant_type(int i)267 virtual QuantizationType softmax_weights_quant_type(int i) const { 268 return QuantizationType::NONE; 269 } 270 271 // Returns pointer to elements of transpose(proto.softmax()), in row-major 272 // order. 273 virtual const void *softmax_weights(int i) const = 0; 274 275 // ** Access methods for optional MatrixParams softmax_bias. 276 // 277 // Returns 1 if proto has optional field softmax_bias, 0 otherwise. 278 virtual int softmax_bias_size() const = 0; 279 280 // Returns number of rows of proto.softmax_bias(). 281 virtual int softmax_bias_num_rows(int i) const = 0; 282 283 // Returns number of columns of proto.softmax_bias(). 284 virtual int softmax_bias_num_cols(int i) const = 0; 285 286 // Returns pointer to elements of proto.softmax_bias(), in row-major order. 287 virtual const void *softmax_bias_weights(int i) const = 0; 288 289 // ** Access methods for repeated int32 embedding_num_features. 290 // 291 // Returns proto.embedding_num_features_size(). 292 virtual int embedding_num_features_size() const = 0; 293 294 // Returns proto.embedding_num_features(i). 295 virtual int embedding_num_features(int i) const = 0; 296 297 // ** Access methods for is_precomputed 298 // 299 // Returns proto.has_is_precomputed(). 300 virtual bool has_is_precomputed() const = 0; 301 302 // Returns proto.is_precomputed(). 303 virtual bool is_precomputed() const = 0; 304 305 protected: CheckIndex(int index,int size,const std::string & description)306 void CheckIndex(int index, int size, const std::string &description) const { 307 SAFTM_CHECK_GE(index, 0) 308 << "Out-of-range index for " << description << ": " << index; 309 SAFTM_CHECK_LT(index, size) 310 << "Out-of-range index for " << description << ": " << index; 311 } 312 }; // class EmbeddingNetworkParams 313 314 } // namespace nlp_saft 315 316 #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_ 317