• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
19 
20 #include <vector>
21 
22 #include "lang_id/common/embedding-network-params.h"
23 #include "lang_id/common/fel/feature-extractor.h"
24 
25 namespace libtextclassifier3 {
26 namespace mobile {
27 
28 // Classifier using a hand-coded feed-forward neural network.
29 //
30 // No gradient computation, just inference.
31 //
32 // Based on the more general nlp_saft::EmbeddingNetwork (without ::mobile).
33 //
34 // Classification works as follows:
35 //
36 // Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
37 //
38 // In words: given some discrete features, this class extracts the embeddings
39 // for these features, concatenates them, passes them through one or more hidden
40 // layers (each layer uses Relu) and next through a softmax layer that computes
41 // an unnormalized score for each possible class.  Note: there is always a
42 // softmax layer at the end.
43 class EmbeddingNetwork {
44  public:
45   // Constructs an embedding network using the parameters from model.
46   //
47   // Note: model should stay alive for at least the lifetime of this
48   // EmbeddingNetwork object.
49   explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
50 
~EmbeddingNetwork()51   virtual ~EmbeddingNetwork() {}
52 
53   // Runs forward computation to fill scores with unnormalized output unit
54   // scores. This is useful for making predictions.
55   void ComputeFinalScores(const std::vector<FeatureVector> &features,
56                           std::vector<float> *scores) const;
57 
58   // Same as above, but allows specification of extra extra neural network
59   // inputs that will be appended to the embedding vector build from features.
60   void ComputeFinalScores(const std::vector<FeatureVector> &features,
61                           const std::vector<float> &extra_inputs,
62                           std::vector<float> *scores) const;
63 
64  private:
65   // Constructs the concatenated input embedding vector in place in output
66   // vector concat.
67   void ConcatEmbeddings(const std::vector<FeatureVector> &features,
68                         std::vector<float> *concat) const;
69 
70   // Pointer to the model object passed to the constructor.  Not owned.
71   const EmbeddingNetworkParams *model_;
72 
73   // Network parameters.
74 
75   // One weight matrix for each embedding.
76   std::vector<EmbeddingNetworkParams::Matrix> embedding_matrices_;
77 
78   // embedding_row_size_in_bytes_[i] is the size (in bytes) of a row from
79   // embedding_matrices_[i].  We precompute this in order to quickly find the
80   // beginning of the k-th row from an embedding matrix (which is stored in
81   // row-major order).
82   std::vector<int> embedding_row_size_in_bytes_;
83 
84   // concat_offset_[i] is the input layer offset for i-th embedding space.
85   std::vector<int> concat_offset_;
86 
87   // Size of the input ("concatenation") layer.
88   int concat_layer_size_ = 0;
89 
90   // One weight matrix and one vector of bias weights for each layer of neurons.
91   // Last layer is the softmax layer, the previous ones are the hidden layers.
92   std::vector<EmbeddingNetworkParams::Matrix> layer_weights_;
93   std::vector<EmbeddingNetworkParams::Matrix> layer_bias_;
94 };
95 
96 }  // namespace mobile
97 }  // namespace nlp_saft
98 
99 #endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
100