/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_

#include <algorithm>
#include <string>

#include "common/float16.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {

enum class QuantizationType { NONE = 0, UINT8 };

// API for accessing parameters for a feed-forward neural network with
// embeddings.
//
// Note: this API is closely related to embedding-network.proto.  The reason we
// have a separate API is that the proto may not be the only way of packaging
// these parameters.
class EmbeddingNetworkParams {
 public:
  virtual ~EmbeddingNetworkParams() {}

  // **** High-level API.

  // Simple representation of a matrix.  This small struct doesn't own any
  // resources and intentionally supports copy / assign, to simplify our APIs.
  struct Matrix {
    // Number of rows.
    int rows;

    // Number of columns.
    int cols;

    QuantizationType quant_type;

    // Pointer to matrix elements, in row-major order
    // (https://en.wikipedia.org/wiki/Row-major_order).  Not owned.
    const void *elements;

    // Quantization scales: one scale for each row.
    const float16 *quant_scales;
  };
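
  // For illustration only; this helper is not part of the original API.  A
  // minimal sketch of reading element (row, col) from a Matrix.  It assumes
  // (a) Float16To32 from common/float16.h converts a float16 scale to float,
  // and (b) UINT8 data uses the simple per-row linear scheme
  // value = scale * quantized; the production dequantization may differ.
  static float GetMatrixElement(const Matrix &matrix, int row, int col) {
    const int index = row * matrix.cols + col;
    if (matrix.quant_type == QuantizationType::NONE) {
      // Unquantized: elements points to plain floats.
      return static_cast<const float *>(matrix.elements)[index];
    }
    // UINT8: dequantize with the quantization scale for this row.
    const unsigned char quantized =
        static_cast<const unsigned char *>(matrix.elements)[index];
    return Float16To32(matrix.quant_scales[row]) * quantized;
  }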

  // Returns number of embedding spaces.
  int GetNumEmbeddingSpaces() const {
    if (embeddings_size() != embedding_num_features_size()) {
      TC_LOG(ERROR) << "Embedding spaces mismatch " << embeddings_size()
                    << " != " << embedding_num_features_size();
    }
    return std::max(0,
                    std::min(embeddings_size(), embedding_num_features_size()));
  }

  // Returns embedding matrix for the i-th embedding space.
  //
  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
  // otherwise.
  Matrix GetEmbeddingMatrix(int i) const {
    TC_DCHECK(InRange(i, embeddings_size()));
    Matrix matrix;
    matrix.rows = embeddings_num_rows(i);
    matrix.cols = embeddings_num_cols(i);
    matrix.elements = embeddings_weights(i);
    matrix.quant_type = embeddings_quant_type(i);
    matrix.quant_scales = embeddings_quant_scales(i);
    return matrix;
  }

  // Returns number of features in i-th embedding space.
  //
  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
  // otherwise.
  int GetNumFeaturesInEmbeddingSpace(int i) const {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return std::max(0, embedding_num_features(i));
  }

  // Returns number of hidden layers in the neural network.  Each such layer
  // has a weight matrix and a bias vector (a matrix with one column).
  int GetNumHiddenLayers() const {
    if (hidden_size() != hidden_bias_size()) {
      TC_LOG(ERROR) << "Hidden layer mismatch " << hidden_size()
                    << " != " << hidden_bias_size();
    }
    return std::max(0, std::min(hidden_size(), hidden_bias_size()));
  }

  // Returns weight matrix for i-th hidden layer.
  //
  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
  // otherwise.
  Matrix GetHiddenLayerMatrix(int i) const {
    TC_DCHECK(InRange(i, hidden_size()));
    Matrix matrix;
    matrix.rows = hidden_num_rows(i);
    matrix.cols = hidden_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_weights(i);
    return matrix;
  }

  // Returns bias matrix for i-th hidden layer.  Technically a Matrix, but we
  // expect it to be a vector (i.e., num cols is 1).
  //
  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
  // otherwise.
  Matrix GetHiddenLayerBias(int i) const {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    Matrix matrix;
    matrix.rows = hidden_bias_num_rows(i);
    matrix.cols = hidden_bias_num_cols(i);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = hidden_bias_weights(i);
    return matrix;
  }

  // Returns true if a softmax layer exists.
  bool HasSoftmaxLayer() const {
    if (softmax_size() != softmax_bias_size()) {
      TC_LOG(ERROR) << "Softmax layer mismatch " << softmax_size()
                    << " != " << softmax_bias_size();
    }
    return (softmax_size() == 1) && (softmax_bias_size() == 1);
  }

  // Returns weight matrix for the softmax layer.
  //
  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
  // behavior otherwise.
  Matrix GetSoftmaxMatrix() const {
    TC_DCHECK(softmax_size() == 1);
    Matrix matrix;
    matrix.rows = softmax_num_rows(0);
    matrix.cols = softmax_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_weights(0);
    return matrix;
  }

  // Returns bias for the softmax layer.  Technically a Matrix, but we expect
  // it to be a column vector (i.e., num cols is 1).
  //
  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
  // behavior otherwise.
  Matrix GetSoftmaxBias() const {
    TC_DCHECK(softmax_bias_size() == 1);
    Matrix matrix;
    matrix.rows = softmax_bias_num_rows(0);
    matrix.cols = softmax_bias_num_cols(0);

    // Quantization not supported here.
    matrix.quant_type = QuantizationType::NONE;
    matrix.elements = softmax_bias_weights(0);
    return matrix;
  }

  // Copies the EmbeddingNetwork-related parameters from this object's TaskSpec
  // into task_context.  Returns true on success, false on error.
  virtual bool UpdateTaskContextParameters(TaskContext *task_context) {
    const TaskSpec *task_spec = GetTaskSpec();
    if (task_spec == nullptr) {
      TC_LOG(ERROR) << "Unable to get TaskSpec";
      return false;
    }
    for (const TaskSpec::Parameter &parameter : task_spec->parameter()) {
      task_context->SetParameter(parameter.name(), parameter.value());
    }
    return true;
  }

  // Returns a pointer to a TaskSpec with the EmbeddingNetwork-related
  // parameters.  Returns nullptr in case of problems.  Ownership of the
  // returned pointer is *not* transferred to the caller.
  virtual const TaskSpec *GetTaskSpec() {
    TC_LOG(ERROR) << "Not implemented";
    return nullptr;
  }

 protected:
  // **** Low-level API.
  //
  // * Most low-level API methods are documented by giving the equivalent
  //   function call on "proto", the original proto (of type
  //   EmbeddingNetworkProto) that was used to generate the C++ code.
  //
  // * To simplify our generation code, optional proto fields of message type
  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
  //   *_size() methods for such optional fields: they return 0 or 1.
  //
  // * "transpose(M)" denotes the transpose of a matrix M.
  //
  // * Behavior is undefined when trying to retrieve a piece of data that does
  //   not exist: e.g., embeddings_num_rows(5) if embeddings_size() == 2.

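  // Example (for illustration; not in the original header): under the
  // convention above, a proto without the optional softmax field yields
  // softmax_size() == 0, and the softmax accessors must not be called; a
  // proto with that field set yields softmax_size() == 1, and the accessors
  // are called with i == 0 only, e.g., softmax_num_rows(0), as in
  // GetSoftmaxMatrix() above.
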
  // ** Access methods for repeated MatrixParams embeddings.
  //
  // Returns proto.embeddings_size().
  virtual int embeddings_size() const = 0;

  // Returns number of rows of transpose(proto.embeddings(i)).
  virtual int embeddings_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.embeddings(i)).
  virtual int embeddings_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.embeddings(i)), in
  // row-major order.  NOTE: for unquantized embeddings, this returns a pointer
  // to float; for quantized embeddings, this returns a pointer to uint8.
  virtual const void *embeddings_weights(int i) const = 0;

  virtual QuantizationType embeddings_quant_type(int i) const {
    return QuantizationType::NONE;
  }

  virtual const float16 *embeddings_quant_scales(int i) const {
    return nullptr;
  }

  // ** Access methods for repeated MatrixParams hidden.
  //
  // Returns proto.hidden_size().
  virtual int hidden_size() const = 0;

  // Returns number of rows of proto.hidden(i).
  virtual int hidden_num_rows(int i) const = 0;

  // Returns number of columns of proto.hidden(i).
  virtual int hidden_num_cols(int i) const = 0;

  // Returns pointer to beginning of array of floats with all values from
  // proto.hidden(i).
  virtual const void *hidden_weights(int i) const = 0;

  // ** Access methods for repeated MatrixParams hidden_bias.
  //
  // Returns proto.hidden_bias_size().
  virtual int hidden_bias_size() const = 0;

  // Returns number of rows of proto.hidden_bias(i).
  virtual int hidden_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.hidden_bias(i).
  virtual int hidden_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
  virtual const void *hidden_bias_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax.
  //
  // Returns 1 if proto has optional field softmax, 0 otherwise.
  virtual int softmax_size() const = 0;

  // Returns number of rows of transpose(proto.softmax()).
  virtual int softmax_num_rows(int i) const = 0;

  // Returns number of columns of transpose(proto.softmax()).
  virtual int softmax_num_cols(int i) const = 0;

  // Returns pointer to elements of transpose(proto.softmax()), in row-major
  // order.
  virtual const void *softmax_weights(int i) const = 0;

  // ** Access methods for optional MatrixParams softmax_bias.
  //
  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
  virtual int softmax_bias_size() const = 0;

  // Returns number of rows of proto.softmax_bias().
  virtual int softmax_bias_num_rows(int i) const = 0;

  // Returns number of columns of proto.softmax_bias().
  virtual int softmax_bias_num_cols(int i) const = 0;

  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
  virtual const void *softmax_bias_weights(int i) const = 0;

  // ** Access methods for repeated int32 embedding_num_features.
  //
  // Returns proto.embedding_num_features_size().
  virtual int embedding_num_features_size() const = 0;

  // Returns proto.embedding_num_features(i).
  virtual int embedding_num_features(int i) const = 0;

  // Returns true if and only if index is in range [0, size).  Logs an error
  // message otherwise.
  static bool InRange(int index, int size) {
    if ((index < 0) || (index >= size)) {
      TC_LOG(ERROR) << "Index " << index << " outside [0, " << size << ")";
      return false;
    }
    return true;
  }
};  // class EmbeddingNetworkParams
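
// For illustration only; not part of the original header.  A minimal sketch of
// how client code might inspect the network shape through the public API
// above, assuming TC_LOG(INFO) is available from util/base/logging.h.
inline void LogNetworkShape(const EmbeddingNetworkParams &params) {
  for (int i = 0; i < params.GetNumEmbeddingSpaces(); ++i) {
    const EmbeddingNetworkParams::Matrix embeddings =
        params.GetEmbeddingMatrix(i);
    TC_LOG(INFO) << "Embedding space " << i << ": " << embeddings.rows << " x "
                 << embeddings.cols << ", "
                 << params.GetNumFeaturesInEmbeddingSpace(i) << " features";
  }
  for (int i = 0; i < params.GetNumHiddenLayers(); ++i) {
    const EmbeddingNetworkParams::Matrix weights =
        params.GetHiddenLayerMatrix(i);
    TC_LOG(INFO) << "Hidden layer " << i << ": " << weights.rows << " x "
                 << weights.cols;
  }
  if (params.HasSoftmaxLayer()) {
    const EmbeddingNetworkParams::Matrix softmax = params.GetSoftmaxMatrix();
    TC_LOG(INFO) << "Softmax: " << softmax.rows << " x " << softmax.cols;
  }
}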

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_