/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_

#include <string>

#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/lite_base/float16.h"
#include "lang_id/common/lite_base/logging.h"

namespace libtextclassifier3 {

enum class QuantizationType {
  NONE = 0,

  // Quantization to 8 bit unsigned ints.
  UINT8,

  // Quantization to 4 bit unsigned ints.
  UINT4,

  // Quantization to 16 bit floats, the type defined in
  // lang_id/common/lite_base/float16.h
  FLOAT16,

  // NOTE: for backward compatibility, if you add a new value to this enum, add
  // it *at the end*, such that you do not change the integer values of the
  // existing enum values.
};

// Converts "UINT8" -> QuantizationType::UINT8, and so on.
QuantizationType ParseQuantizationType(const std::string &s);
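//
// Illustrative usage (a sketch; behavior on unknown strings is defined by the
// implementation in the corresponding .cc file):
//
//   QuantizationType type = ParseQuantizationType("UINT8");
//   // type == QuantizationType::UINT8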

// API for accessing parameters for a feed-forward neural network with
// embeddings.
//
// In fact, we provide two APIs: a high-level (and highly-recommended) API,
// with methods named using the BigCamel notation (e.g., GetEmbeddingMatrix())
// and a low-level API, using C-style names (e.g., softmax_num_cols()).
//
// Note: the API below is meant to allow the inference code (the class
// libtextclassifier3::mobile::EmbeddingNetwork) to use the data directly, with
// no need for transposing any matrix (which would require extra overhead on
// mobile devices).  Hence, as indicated by the comments for the API methods,
// some of the matrices below are the transposes of the corresponding matrices
// from the original proto.
class EmbeddingNetworkParams {
 public:
  virtual ~EmbeddingNetworkParams() {}

  // Returns true if these params are valid.  False otherwise (e.g., if the
  // underlying data is corrupted).  If is_valid() returns false, clients
  // should not call any other method on that instance of
  // EmbeddingNetworkParams.  If is_valid() returns true, then calls to the API
  // methods below should not crash *if they are called with index parameters
  // in bounds*.  E.g., if is_valid() and 0 <= i < embeddings_size(), then
  // GetEmbeddingMatrix(i) should not crash.
  virtual bool is_valid() const = 0;

  // **** High-level API.

  // Simple representation of a matrix.  This small struct doesn't own any
  // resources and intentionally supports copy / assign, to simplify our APIs.
  struct Matrix {
    // Number of rows.
    int rows = 0;

    // Number of columns.
    int cols = 0;

    QuantizationType quant_type = QuantizationType::NONE;

    // Pointer to matrix elements, in row-major order
    // (https://en.wikipedia.org/wiki/Row-major_order).  Not owned.
    const void *elements = nullptr;

    // Quantization scales: one scale for each row.
    const ::libtextclassifier3::mobile::float16 *quant_scales = nullptr;
  };
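
  // A minimal sketch of reading one element from a UINT8-quantized Matrix,
  // assuming the common per-row-scale scheme (value = scale[row] * byte) and
  // a Float16To32 conversion helper; the authoritative dequantization lives
  // in the inference code, so treat this as illustration only:
  //
  //   float GetElement(const Matrix &m, int row, int col) {
  //     // Only meaningful for UINT8 quantization with per-row scales.
  //     const uint8 *bytes = reinterpret_cast<const uint8 *>(m.elements);
  //     const float scale = Float16To32(m.quant_scales[row]);
  //     return scale * static_cast<float>(bytes[row * m.cols + col]);
  //   }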

  // Returns i-th embedding matrix.  Crashes on out of bounds indices.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetEmbeddingMatrix(int i) const {
    CheckIndex(i, embeddings_size(), "embedding matrix");
    Matrix matrix;
    matrix.rows = embeddings_num_rows(i);
    matrix.cols = embeddings_num_cols(i);
    matrix.elements = embeddings_weights(i);
    matrix.quant_type = embeddings_quant_type(i);
    matrix.quant_scales = embeddings_quant_scales(i);
    return matrix;
  }
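
  // Typical guarded access (illustrative; params is any concrete
  // implementation of this interface):
  //
  //   if (params->is_valid()) {
  //     for (int i = 0; i < params->embeddings_size(); ++i) {
  //       EmbeddingNetworkParams::Matrix m = params->GetEmbeddingMatrix(i);
  //       // m.elements points to m.rows * m.cols values whose type depends
  //       // on m.quant_type: float for NONE, uint8 for UINT8.
  //     }
  //   }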

  // Returns weight matrix for i-th hidden layer.  Crashes on out of bounds
  // indices.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetHiddenLayerMatrix(int i) const {
    CheckIndex(i, hidden_size(), "hidden layer");
    Matrix matrix;
    matrix.rows = hidden_num_rows(i);
    matrix.cols = hidden_num_cols(i);

    // Quantization type comes from the low-level accessor, which defaults to
    // NONE but may be overridden by implementations.
    matrix.quant_type = hidden_weights_quant_type(i);
    matrix.elements = hidden_weights(i);
    return matrix;
  }

  // Returns bias for i-th hidden layer.  Technically a Matrix, but we expect
  // it to be a row/column vector (i.e., num rows or num cols is 1).  However,
  // we don't CHECK for that: we just provide access to underlying data.
  // Crashes on out of bounds indices.
  Matrix GetHiddenLayerBias(int i) const {
    CheckIndex(i, hidden_bias_size(), "hidden layer bias");
    Matrix matrix;
    matrix.rows = hidden_bias_num_rows(i);
    matrix.cols = hidden_bias_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_bias_weights(i);
    return matrix;
  }

  // Returns true if a softmax layer exists.
  bool HasSoftmax() const {
    return softmax_size() == 1;
  }

  // Returns weight matrix for the softmax layer.  Note: should be called only
  // if HasSoftmax() is true.
  //
  // This is the transpose of the corresponding matrix from the original proto.
  Matrix GetSoftmaxMatrix() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_num_rows(0);
    matrix.cols = softmax_num_cols(0);

    // Quantization type comes from the low-level accessor, which defaults to
    // NONE but may be overridden by implementations.
    matrix.quant_type = softmax_weights_quant_type(0);
    matrix.elements = softmax_weights(0);
    return matrix;
  }

  // Returns bias for the softmax layer.  Technically a Matrix, but we expect
  // it to be a row/column vector (i.e., num rows or num cols is 1).  However,
  // we don't CHECK for that: we just provide access to underlying data.
  Matrix GetSoftmaxBias() const {
    SAFTM_CHECK(HasSoftmax()) << "No softmax layer.";
    Matrix matrix;
    matrix.rows = softmax_bias_num_rows(0);
    matrix.cols = softmax_bias_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_bias_weights(0);
    return matrix;
  }
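
  // Illustrative pattern for reading the softmax parameters; the HasSoftmax()
  // guard is required (the getters CHECK-fail without it), the rest is a
  // sketch:
  //
  //   if (params->HasSoftmax()) {
  //     EmbeddingNetworkParams::Matrix weights = params->GetSoftmaxMatrix();
  //     EmbeddingNetworkParams::Matrix bias = params->GetSoftmaxBias();
  //     // bias is expected to be a vector: bias.rows == 1 or bias.cols == 1.
  //   }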

  // Updates the EmbeddingNetwork-related parameters from task_context.
  // Returns true on success, false on error.
  virtual bool UpdateTaskContextParameters(
      mobile::TaskContext *task_context) = 0;
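
  // Illustrative call site (TaskContext setup elided; see
  // lang_id/common/fel/task-context.h):
  //
  //   mobile::TaskContext context;
  //   if (!params->UpdateTaskContextParameters(&context)) {
  //     // Handle invalid / corrupted parameters.
  //   }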

  // **** Low-level API.
  //
  // * Most low-level API methods are documented by giving an equivalent
  //   function call on proto, the original proto (of type
  //   EmbeddingNetworkProto) which was used to generate the C++ code.
  //
  // * To simplify our generation code, optional proto fields of message type
  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
  //   *_size() methods for such optional fields: they return 0 or 1.
  //
  // * "transpose(M)" denotes the transpose of a matrix M.
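  //
  // Worked example of the transpose convention (follows from the comments
  // below, shown here for concreteness):
  //
  //   // If proto.embeddings(i) has shape [R, C], then:
  //   //   embeddings_num_rows(i) == C   // rows of transpose(proto matrix)
  //   //   embeddings_num_cols(i) == R   // cols of transpose(proto matrix)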

  // ** Access methods for repeated MatrixParams embeddings.
  //
  // Returns proto.embeddings_size().
  virtual int embeddings_size() const = 0;

  // Returns number of rows of transpose(proto.embeddings(i)).
  virtual int embeddings_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.embeddings(i)).
  virtual int embeddings_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.embeddings(i)), in
  // row-major order.  NOTE: for unquantized embeddings, this returns a pointer
  // to float; for quantized embeddings, this returns a pointer to uint8.
  virtual const void *embeddings_weights(int i) const = 0;

  // Returns quantization type for the i-th embedding matrix.  NONE by
  // default; overridden by implementations that store quantized embeddings.
  virtual QuantizationType embeddings_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to per-row quantization scales for the i-th embedding
  // matrix, or nullptr if the embeddings are not quantized.
  virtual const ::libtextclassifier3::mobile::float16 *embeddings_quant_scales(
      int i) const {
    return nullptr;
  }

  // ** Access methods for repeated MatrixParams hidden.
  //
  // Returns embedding_network_proto.hidden_size().
  virtual int hidden_size() const = 0;

  // Returns embedding_network_proto.hidden(i).rows().
  virtual int hidden_num_rows(int i) const = 0;

  // Returns embedding_network_proto.hidden(i).cols().
  virtual int hidden_num_cols(int i) const = 0;

  // Returns quantization mode for the weights of the i-th hidden layer.
  virtual QuantizationType hidden_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to beginning of array of floats with all values from
  // embedding_network_proto.hidden(i).
  virtual const void *hidden_weights(int i) const = 0;

  // ** Access methods for repeated MatrixParams hidden_bias.
  //
  // Returns proto.hidden_bias_size().
  virtual int hidden_bias_size() const = 0;

  // Returns number of rows of proto.hidden_bias(i).
  virtual int hidden_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.hidden_bias(i).
  virtual int hidden_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
  virtual const void *hidden_bias_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax.
  //
  // Returns 1 if proto has optional field softmax, 0 otherwise.
  virtual int softmax_size() const = 0;

  // Returns number of rows of transpose(proto.softmax()).
  virtual int softmax_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.softmax()).
  virtual int softmax_num_cols(int i) const = 0;

  // Returns quantization mode for the softmax weights.
  virtual QuantizationType softmax_weights_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  // Returns pointer to elements of transpose(proto.softmax()), in row-major
  // order.
  virtual const void *softmax_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax_bias.
  //
  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
  virtual int softmax_bias_size() const = 0;

  // Returns number of rows of proto.softmax_bias().
  virtual int softmax_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.softmax_bias().
  virtual int softmax_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
  virtual const void *softmax_bias_weights(int i) const = 0;

  // ** Access methods for repeated int32 embedding_num_features.
  //
  // Returns proto.embedding_num_features_size().
  virtual int embedding_num_features_size() const = 0;

  // Returns proto.embedding_num_features(i).
  virtual int embedding_num_features(int i) const = 0;

  // ** Access methods for is_precomputed.
  //
  // Returns proto.has_is_precomputed().
  virtual bool has_is_precomputed() const = 0;

  // Returns proto.is_precomputed().
  virtual bool is_precomputed() const = 0;

 protected:
  void CheckIndex(int index, int size, const std::string &description) const {
    SAFTM_CHECK_GE(index, 0)
        << "Out-of-range index for " << description << ": " << index;
    SAFTM_CHECK_LT(index, size)
        << "Out-of-range index for " << description << ": " << index;
  }
};  // class EmbeddingNetworkParams

}  // namespace libtextclassifier3

#endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_PARAMS_H_