1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_ 18 #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_ 19 20 #include <algorithm> 21 #include <string> 22 23 #include "common/float16.h" 24 #include "common/task-context.h" 25 #include "common/task-spec.pb.h" 26 #include "util/base/logging.h" 27 28 namespace libtextclassifier { 29 namespace nlp_core { 30 31 enum class QuantizationType { NONE = 0, UINT8 }; 32 33 // API for accessing parameters for a feed-forward neural network with 34 // embeddings. 35 // 36 // Note: this API is closely related to embedding-network.proto. The reason we 37 // have a separate API is that the proto may not be the only way of packaging 38 // these parameters. 39 class EmbeddingNetworkParams { 40 public: ~EmbeddingNetworkParams()41 virtual ~EmbeddingNetworkParams() {} 42 43 // **** High-level API. 44 45 // Simple representation of a matrix. This small struct that doesn't own any 46 // resource intentionally supports copy / assign, to simplify our APIs. 47 struct Matrix { 48 // Number of rows. 49 int rows; 50 51 // Number of columns. 52 int cols; 53 54 QuantizationType quant_type; 55 56 // Pointer to matrix elements, in row-major order 57 // (https://en.wikipedia.org/wiki/Row-major_order) Not owned. 58 const void *elements; 59 60 // Quantization scales: one scale for each row. 61 const float16 *quant_scales; 62 }; 63 64 // Returns number of embedding spaces. GetNumEmbeddingSpaces()65 int GetNumEmbeddingSpaces() const { 66 if (embeddings_size() != embedding_num_features_size()) { 67 TC_LOG(ERROR) << "Embedding spaces mismatch " << embeddings_size() 68 << " != " << embedding_num_features_size(); 69 } 70 return std::max(0, 71 std::min(embeddings_size(), embedding_num_features_size())); 72 } 73 74 // Returns embedding matrix for the i-th embedding space. 75 // 76 // NOTE: i must be in [0, GetNumEmbeddingSpaces()). Undefined behavior 77 // otherwise. GetEmbeddingMatrix(int i)78 Matrix GetEmbeddingMatrix(int i) const { 79 TC_DCHECK(InRange(i, embeddings_size())); 80 Matrix matrix; 81 matrix.rows = embeddings_num_rows(i); 82 matrix.cols = embeddings_num_cols(i); 83 matrix.elements = embeddings_weights(i); 84 matrix.quant_type = embeddings_quant_type(i); 85 matrix.quant_scales = embeddings_quant_scales(i); 86 return matrix; 87 } 88 89 // Returns number of features in i-th embedding space. 90 // 91 // NOTE: i must be in [0, GetNumEmbeddingSpaces()). Undefined behavior 92 // otherwise. GetNumFeaturesInEmbeddingSpace(int i)93 int GetNumFeaturesInEmbeddingSpace(int i) const { 94 TC_DCHECK(InRange(i, embedding_num_features_size())); 95 return std::max(0, embedding_num_features(i)); 96 } 97 98 // Returns number of hidden layers in the neural network. Each such layer has 99 // weight matrix and a bias vector (a matrix with one column). GetNumHiddenLayers()100 int GetNumHiddenLayers() const { 101 if (hidden_size() != hidden_bias_size()) { 102 TC_LOG(ERROR) << "Hidden layer mismatch " << hidden_size() 103 << " != " << hidden_bias_size(); 104 } 105 return std::max(0, std::min(hidden_size(), hidden_bias_size())); 106 } 107 108 // Returns weight matrix for i-th hidden layer. 109 // 110 // NOTE: i must be in [0, GetNumHiddenLayers()). Undefined behavior 111 // otherwise. GetHiddenLayerMatrix(int i)112 Matrix GetHiddenLayerMatrix(int i) const { 113 TC_DCHECK(InRange(i, hidden_size())); 114 Matrix matrix; 115 matrix.rows = hidden_num_rows(i); 116 matrix.cols = hidden_num_cols(i); 117 118 // Quantization not supported here. 119 matrix.quant_type = QuantizationType::NONE; 120 matrix.elements = hidden_weights(i); 121 return matrix; 122 } 123 124 // Returns bias matrix for i-th hidden layer. Technically a Matrix, but we 125 // expect it to be a vector (i.e., num cols is 1). 126 // 127 // NOTE: i must be in [0, GetNumHiddenLayers()). Undefined behavior 128 // otherwise. GetHiddenLayerBias(int i)129 Matrix GetHiddenLayerBias(int i) const { 130 TC_DCHECK(InRange(i, hidden_bias_size())); 131 Matrix matrix; 132 matrix.rows = hidden_bias_num_rows(i); 133 matrix.cols = hidden_bias_num_cols(i); 134 135 // Quantization not supported here. 136 matrix.quant_type = QuantizationType::NONE; 137 matrix.elements = hidden_bias_weights(i); 138 return matrix; 139 } 140 141 // Returns true if a softmax layer exists. HasSoftmaxLayer()142 bool HasSoftmaxLayer() const { 143 if (softmax_size() != softmax_bias_size()) { 144 TC_LOG(ERROR) << "Softmax layer mismatch " << softmax_size() 145 << " != " << softmax_bias_size(); 146 } 147 return (softmax_size() == 1) && (softmax_bias_size() == 1); 148 } 149 150 // Returns weight matrix for the softmax layer. 151 // 152 // NOTE: Should be called only if HasSoftmaxLayer() is true. Undefined 153 // behavior otherwise. GetSoftmaxMatrix()154 Matrix GetSoftmaxMatrix() const { 155 TC_DCHECK(softmax_size() == 1); 156 Matrix matrix; 157 matrix.rows = softmax_num_rows(0); 158 matrix.cols = softmax_num_cols(0); 159 160 // Quantization not supported here. 161 matrix.quant_type = QuantizationType::NONE; 162 matrix.elements = softmax_weights(0); 163 return matrix; 164 } 165 166 // Returns bias for the softmax layer. Technically a Matrix, but we expect it 167 // to be a row/column vector (i.e., num cols is 1). 168 // 169 // NOTE: Should be called only if HasSoftmaxLayer() is true. Undefined 170 // behavior otherwise. GetSoftmaxBias()171 Matrix GetSoftmaxBias() const { 172 TC_DCHECK(softmax_bias_size() == 1); 173 Matrix matrix; 174 matrix.rows = softmax_bias_num_rows(0); 175 matrix.cols = softmax_bias_num_cols(0); 176 177 // Quantization not supported here. 178 matrix.quant_type = QuantizationType::NONE; 179 matrix.elements = softmax_bias_weights(0); 180 return matrix; 181 } 182 183 // Updates the EmbeddingNetwork-related parameters from task_context. Returns 184 // true on success, false on error. UpdateTaskContextParameters(TaskContext * task_context)185 virtual bool UpdateTaskContextParameters(TaskContext *task_context) { 186 const TaskSpec *task_spec = GetTaskSpec(); 187 if (task_spec == nullptr) { 188 TC_LOG(ERROR) << "Unable to get TaskSpec"; 189 return false; 190 } 191 for (const TaskSpec::Parameter ¶meter : task_spec->parameter()) { 192 task_context->SetParameter(parameter.name(), parameter.value()); 193 } 194 return true; 195 } 196 197 // Returns a pointer to a TaskSpec with the EmbeddingNetwork-related 198 // parameters. Returns nullptr in case of problems. Ownership with the 199 // returned pointer is *not* transfered to the caller. GetTaskSpec()200 virtual const TaskSpec *GetTaskSpec() { 201 TC_LOG(ERROR) << "Not implemented"; 202 return nullptr; 203 } 204 205 protected: 206 // **** Low-level API. 207 // 208 // * Most low-level API methods are documented by giving an equivalent 209 // function call on proto, the original proto (of type 210 // EmbeddingNetworkProto) which was used to generate the C++ code. 211 // 212 // * To simplify our generation code, optional proto fields of message type 213 // are treated as repeated fields with 0 or 1 instances. As such, we have 214 // *_size() methods for such optional fields: they return 0 or 1. 215 // 216 // * "transpose(M)" denotes the transpose of a matrix M. 217 // 218 // * Behavior is undefined when trying to retrieve a piece of data that does 219 // not exist: e.g., embeddings_num_rows(5) if embeddings_size() == 2. 220 221 // ** Access methods for repeated MatrixParams embeddings. 222 // 223 // Returns proto.embeddings_size(). 224 virtual int embeddings_size() const = 0; 225 226 // Returns number of rows of transpose(proto.embeddings(i)). 227 virtual int embeddings_num_rows(int i) const = 0; 228 229 // Returns number of columns of transpose(proto.embeddings(i)). 230 virtual int embeddings_num_cols(int i) const = 0; 231 232 // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major 233 // order. NOTE: for unquantized embeddings, this returns a pointer to float; 234 // for quantized embeddings, this returns a pointer to uint8. 235 virtual const void *embeddings_weights(int i) const = 0; 236 embeddings_quant_type(int i)237 virtual QuantizationType embeddings_quant_type(int i) const { 238 return QuantizationType::NONE; 239 } 240 embeddings_quant_scales(int i)241 virtual const float16 *embeddings_quant_scales(int i) const { 242 return nullptr; 243 } 244 245 // ** Access methods for repeated MatrixParams hidden. 246 // 247 // Returns embedding_network_proto.hidden_size(). 248 virtual int hidden_size() const = 0; 249 250 // Returns embedding_network_proto.hidden(i).rows(). 251 virtual int hidden_num_rows(int i) const = 0; 252 253 // Returns embedding_network_proto.hidden(i).rows(). 254 virtual int hidden_num_cols(int i) const = 0; 255 256 // Returns pointer to beginning of array of floats with all values from 257 // embedding_network_proto.hidden(i). 258 virtual const void *hidden_weights(int i) const = 0; 259 260 // ** Access methods for repeated MatrixParams hidden_bias. 261 // 262 // Returns proto.hidden_bias_size(). 263 virtual int hidden_bias_size() const = 0; 264 265 // Returns number of rows of proto.hidden_bias(i). 266 virtual int hidden_bias_num_rows(int i) const = 0; 267 268 // Returns number of columns of proto.hidden_bias(i). 269 virtual int hidden_bias_num_cols(int i) const = 0; 270 271 // Returns pointer to elements of proto.hidden_bias(i), in row-major order. 272 virtual const void *hidden_bias_weights(int i) const = 0; 273 274 // ** Access methods for optional MatrixParams softmax. 275 // 276 // Returns 1 if proto has optional field softmax, 0 otherwise. 277 virtual int softmax_size() const = 0; 278 279 // Returns number of rows of transpose(proto.softmax()). 280 virtual int softmax_num_rows(int i) const = 0; 281 282 // Returns number of columns of transpose(proto.softmax()). 283 virtual int softmax_num_cols(int i) const = 0; 284 285 // Returns pointer to elements of transpose(proto.softmax()), in row-major 286 // order. 287 virtual const void *softmax_weights(int i) const = 0; 288 289 // ** Access methods for optional MatrixParams softmax_bias. 290 // 291 // Returns 1 if proto has optional field softmax_bias, 0 otherwise. 292 virtual int softmax_bias_size() const = 0; 293 294 // Returns number of rows of proto.softmax_bias(). 295 virtual int softmax_bias_num_rows(int i) const = 0; 296 297 // Returns number of columns of proto.softmax_bias(). 298 virtual int softmax_bias_num_cols(int i) const = 0; 299 300 // Returns pointer to elements of proto.softmax_bias(), in row-major order. 301 virtual const void *softmax_bias_weights(int i) const = 0; 302 303 // ** Access methods for repeated int32 embedding_num_features. 304 // 305 // Returns proto.embedding_num_features_size(). 306 virtual int embedding_num_features_size() const = 0; 307 308 // Returns proto.embedding_num_features(i). 309 virtual int embedding_num_features(int i) const = 0; 310 311 // Returns true if and only if index is in range [0, size). Log an error 312 // message otherwise. InRange(int index,int size)313 static bool InRange(int index, int size) { 314 if ((index < 0) || (index >= size)) { 315 TC_LOG(ERROR) << "Index " << index << " outside [0, " << size << ")"; 316 return false; 317 } 318 return true; 319 } 320 }; // class EmbeddingNetworkParams 321 322 } // namespace nlp_core 323 } // namespace libtextclassifier 324 325 #endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_ 326