/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/base/integral_types.h"

namespace libtextclassifier {
namespace nlp_core {

// EmbeddingNetworkParams backed by a memory image.
//
// In this context, a memory image is like an EmbeddingNetworkProto, but with
// all repeated weights (>99% of the size) directly usable (with no parsing
// required).
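//
// Minimal usage sketch (illustrative only; image_start and image_num_bytes
// are hypothetical names for a memory image already loaded by the caller,
// e.g., via mmap):
//
//   EmbeddingNetworkParamsFromImage params(image_start, image_num_bytes);
//   const TaskSpec *task_spec = params.GetTaskSpec();  // nullptr if absent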
class EmbeddingNetworkParamsFromImage : public EmbeddingNetworkParams {
 public:
  // Constructs an EmbeddingNetworkParamsFromImage, using the memory image that
  // starts at address start and contains num_bytes bytes.
  EmbeddingNetworkParamsFromImage(const void *start, uint64 num_bytes)
      : memory_reader_(start, num_bytes),
        trimmed_proto_(memory_reader_.trimmed_proto()) {
    embeddings_blob_offset_ = 0;

    hidden_blob_offset_ = embeddings_blob_offset_ + embeddings_size();
    if (trimmed_proto_.embeddings_size() &&
        trimmed_proto_.embeddings(0).is_quantized()) {
      // Adjust for quantization: each quantized matrix takes two blobs (instead
      // of one): one for the quantized values and one for the scales.
      hidden_blob_offset_ += embeddings_size();
    }

    hidden_bias_blob_offset_ = hidden_blob_offset_ + hidden_size();
    softmax_blob_offset_ = hidden_bias_blob_offset_ + hidden_bias_size();
    softmax_bias_blob_offset_ = softmax_blob_offset_ + softmax_size();
  }

  ~EmbeddingNetworkParamsFromImage() override {}

  const TaskSpec *GetTaskSpec() override {
    auto extension_id = task_spec_in_embedding_network_proto;
    if (trimmed_proto_.HasExtension(extension_id)) {
      return &(trimmed_proto_.GetExtension(extension_id));
    } else {
      return nullptr;
    }
  }

 protected:
  int embeddings_size() const override {
    return trimmed_proto_.embeddings_size();
  }

  int embeddings_num_rows(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return trimmed_proto_.embeddings(i).rows();
  }

  int embeddings_num_cols(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return trimmed_proto_.embeddings(i).cols();
  }

  const void *embeddings_weights(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    const int blob_index = trimmed_proto_.embeddings(i).is_quantized()
                               ? (embeddings_blob_offset_ + 2 * i)
                               : (embeddings_blob_offset_ + i);
    DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
    return data_blob_view.data();
  }

  QuantizationType embeddings_quant_type(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (trimmed_proto_.embeddings(i).is_quantized()) {
      return QuantizationType::UINT8;
    } else {
      return QuantizationType::NONE;
    }
  }

  const float16 *embeddings_quant_scales(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (trimmed_proto_.embeddings(i).is_quantized()) {
103       // Each embedding matrix has two atttached data blobs (hence the "2 * i"):
104       // one blob with the quantized values and (immediately after it, hence the
105       // "+ 1") one blob with the scales.
      int blob_index = embeddings_blob_offset_ + 2 * i + 1;
      DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
      return reinterpret_cast<const float16 *>(data_blob_view.data());
    } else {
      return nullptr;
    }
  }

  int hidden_size() const override { return trimmed_proto_.hidden_size(); }

  int hidden_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return trimmed_proto_.hidden(i).rows();
  }

  int hidden_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return trimmed_proto_.hidden(i).cols();
  }

  const void *hidden_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(hidden_blob_offset_ + i);
    return data_blob_view.data();
  }

  int hidden_bias_size() const override {
    return trimmed_proto_.hidden_bias_size();
  }

  int hidden_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return trimmed_proto_.hidden_bias(i).rows();
  }

  int hidden_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return trimmed_proto_.hidden_bias(i).cols();
  }

  const void *hidden_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i);
    return data_blob_view.data();
  }

  int softmax_size() const override {
    return trimmed_proto_.has_softmax() ? 1 : 0;
  }

  int softmax_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return trimmed_proto_.softmax().rows();
  }

  int softmax_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return trimmed_proto_.softmax().cols();
  }

  const void *softmax_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(softmax_blob_offset_ + i);
    return data_blob_view.data();
  }

  int softmax_bias_size() const override {
    return trimmed_proto_.has_softmax_bias() ? 1 : 0;
  }

  int softmax_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return trimmed_proto_.softmax_bias().rows();
  }

  int softmax_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return trimmed_proto_.softmax_bias().cols();
  }

  const void *softmax_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i);
    return data_blob_view.data();
  }

  int embedding_num_features_size() const override {
    return trimmed_proto_.embedding_num_features_size();
  }

  int embedding_num_features(int i) const override {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return trimmed_proto_.embedding_num_features(i);
  }

 private:
  MemoryImageReader<EmbeddingNetworkProto> memory_reader_;

  const EmbeddingNetworkProto &trimmed_proto_;

  // 0-based offsets in the list of data blobs for the different MatrixParams
  // fields.  E.g., the 1st hidden MatrixParams has its weights stored in the
  // data blob number hidden_blob_offset_, the 2nd one in hidden_blob_offset_ +
  // 1, and so on.
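  //
  // Illustrative layout (hypothetical sizes, not taken from any real model):
  // with 3 quantized embedding matrices, 2 hidden matrices, 2 hidden biases,
  // 1 softmax matrix and 1 softmax bias, the data blobs would be laid out as
  //   blobs 0-5:  embedding value/scale pairs (embeddings_blob_offset_ == 0)
  //   blobs 6-7:  hidden weights              (hidden_blob_offset_ == 6)
  //   blobs 8-9:  hidden biases               (hidden_bias_blob_offset_ == 8)
  //   blob 10:    softmax weights             (softmax_blob_offset_ == 10)
  //   blob 11:    softmax bias                (softmax_bias_blob_offset_ == 11)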
  int embeddings_blob_offset_;
  int hidden_blob_offset_;
  int hidden_bias_blob_offset_;
  int softmax_blob_offset_;
  int softmax_bias_blob_offset_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_