/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/float16.h"
#include "common/little-endian-data.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {
// A wrapper class that owns and exposes an EmbeddingNetworkProto message via
// the EmbeddingNetworkParams interface.
//
// The EmbeddingNetworkParams interface encapsulates the weight matrices of the
// embeddings, hidden and softmax layers as transposed versions of their
// counterparts in the original EmbeddingNetworkProto. The matrices in the proto
// passed to this class' constructor must likewise already have been transposed.
// See embedding-network-params.h for details.
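//
// A minimal usage sketch, assuming |serialized| is a std::string that holds a
// serialized EmbeddingNetworkProto meeting the requirements above:
//
//   std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto());
//   if (!proto->ParseFromString(serialized)) {
//     TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto";
//   }
//   EmbeddingNetworkParamsFromProto params(std::move(proto));
//   if (!params.is_valid()) {
//     TC_LOG(ERROR) << "Corrupted proto data";
//   }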
class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
 public:
  // Constructor that takes ownership of the provided proto. See the class
  // comment for the requirements that certain weight matrices must satisfy.
  explicit EmbeddingNetworkParamsFromProto(
      std::unique_ptr<EmbeddingNetworkProto> proto)
      : proto_(std::move(proto)) {
    valid_ = true;

    // Initialize these vectors to have the required number of elements
    // regardless of quantization status. This supports the unlikely case
    // where only some embeddings are quantized, and the fact that the
    // EmbeddingNetworkParams interface accesses them by index.
    embeddings_quant_scales_.resize(proto_->embeddings_size());
    embeddings_quant_weights_.resize(proto_->embeddings_size());
    for (int i = 0; i < proto_->embeddings_size(); ++i) {
      MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
      if (!embedding->is_quantized()) {
        continue;
      }

      bool success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_quantized_values(),
          embedding->rows() * embedding->cols(),
          &(embeddings_quant_weights_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #" << i;
        valid_ = false;
      }

      // The repeated field bytes_for_quantized_values uses a lot of memory.
      // Since it's no longer necessary (and we own the proto), we clear it.
      embedding->clear_bytes_for_quantized_values();

      success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_col_scales(),
          embedding->rows(),
          &(embeddings_quant_scales_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
        valid_ = false;
      }

      // See comments for clear_bytes_for_quantized_values().
      embedding->clear_bytes_for_col_scales();
    }
  }

  const TaskSpec *GetTaskSpec() override {
    if (!proto_) {
      return nullptr;
    }
    auto extension_id = task_spec_in_embedding_network_proto;
    if (proto_->HasExtension(extension_id)) {
      return &(proto_->GetExtension(extension_id));
    } else {
      TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
      return nullptr;
    }
  }

  // Returns true if these params are valid.  False otherwise (e.g., if the
  // original proto data was corrupted).
  bool is_valid() { return valid_; }

 protected:
  int embeddings_size() const override { return proto_->embeddings_size(); }

  int embeddings_num_rows(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).rows();
  }

  int embeddings_num_cols(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).cols();
  }

  const void *embeddings_weights(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (proto_->embeddings(i).is_quantized()) {
      return static_cast<const void *>(embeddings_quant_weights_.at(i).data());
    } else {
      return static_cast<const void *>(proto_->embeddings(i).value().data());
    }
  }

  QuantizationType embeddings_quant_type(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
                                                : QuantizationType::NONE;
  }

  const float16 *embeddings_quant_scales(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized()
               ? embeddings_quant_scales_.at(i).data()
               : nullptr;
  }

  int hidden_size() const override { return proto_->hidden_size(); }

  int hidden_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).rows();
  }

  int hidden_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).cols();
  }

  const void *hidden_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).value().data();
  }

  int hidden_bias_size() const override { return proto_->hidden_bias_size(); }

  int hidden_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).rows();
  }

  int hidden_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).cols();
  }

  const void *hidden_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).value().data();
  }

  int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; }

  int softmax_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().rows() : 0;
  }

  int softmax_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().cols() : 0;
  }

  const void *softmax_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
  }

  int softmax_bias_size() const override {
    return proto_->has_softmax_bias() ? 1 : 0;
  }

  int softmax_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
  }

  int softmax_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
  }

  const void *softmax_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
                                      : nullptr;
  }

  int embedding_num_features_size() const override {
    return proto_->embedding_num_features_size();
  }

  int embedding_num_features(int i) const override {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return proto_->embedding_num_features(i);
  }

 private:
  std::unique_ptr<EmbeddingNetworkProto> proto_;

  // True if these params are valid.  May be false if the original proto was
  // corrupted.  We prefer setting this to false over CHECK-failing.
  bool valid_;

  // When the embeddings are quantized, these members store their numeric
  // values using the types expected by the rest of the class. For technical
  // reasons, the proto stores this info using larger types (i.e., more bits).
  std::vector<std::vector<float16>> embeddings_quant_scales_;
  std::vector<std::vector<uint8>> embeddings_quant_weights_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_