/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/float16.h"
#include "common/little-endian-data.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {

// A wrapper class that owns and exposes an EmbeddingNetworkProto message via
// the EmbeddingNetworkParams interface.
//
// The EmbeddingNetworkParams interface encapsulates the weight matrices of the
// embeddings, hidden and softmax layers as transposed versions of their
// counterparts in the original EmbeddingNetworkProto. The matrices in the
// proto passed to this class' constructor must likewise already have been
// transposed. See embedding-network-params.h for details.
class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
 public:
  // Constructor that takes ownership of the provided proto. See class-comment
  // for the requirements that certain weight matrices must satisfy.
  explicit EmbeddingNetworkParamsFromProto(
      std::unique_ptr<EmbeddingNetworkProto> proto)
      : proto_(std::move(proto)) {
    valid_ = true;

    // Initialize these vectors to have the required number of elements
    // regardless of quantization status. This is to support the unlikely case
    // where only some embeddings are quantized, along with the fact that the
    // EmbeddingNetworkParams interface accesses them by index.
    embeddings_quant_scales_.resize(proto_->embeddings_size());
    embeddings_quant_weights_.resize(proto_->embeddings_size());
    for (int i = 0; i < proto_->embeddings_size(); ++i) {
      MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
      if (!embedding->is_quantized()) {
        continue;
      }

      bool success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_quantized_values(),
          embedding->rows() * embedding->cols(),
          &(embeddings_quant_weights_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #"
                      << i;
        valid_ = false;
      }

      // The repeated field bytes_for_quantized_values uses a lot of memory.
      // Since it's no longer necessary (and we own the proto), we clear it.
      embedding->clear_bytes_for_quantized_values();

      success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_col_scales(),
          embedding->rows(),
          &(embeddings_quant_scales_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
        valid_ = false;
      }

      // See comments for clear_bytes_for_quantized_values().
      embedding->clear_bytes_for_col_scales();
    }
  }
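
  // A minimal usage sketch (illustrative comment only, not part of this
  // class; |serialized| is an assumed std::string holding a serialized
  // EmbeddingNetworkProto):
  //
  //   std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto);
  //   if (proto->ParseFromString(serialized)) {
  //     EmbeddingNetworkParamsFromProto params(std::move(proto));
  //     if (!params.is_valid()) {
  //       // The quantized embedding data could not be decoded.
  //     }
  //   }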

  const TaskSpec *GetTaskSpec() override {
    if (!proto_) {
      return nullptr;
    }
    auto extension_id = task_spec_in_embedding_network_proto;
    if (proto_->HasExtension(extension_id)) {
      return &(proto_->GetExtension(extension_id));
    } else {
      TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
      return nullptr;
    }
  }

  // Returns true if these params are valid. False otherwise (e.g., if the
  // original proto data was corrupted).
  bool is_valid() { return valid_; }

 protected:
  int embeddings_size() const override { return proto_->embeddings_size(); }

  int embeddings_num_rows(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).rows();
  }

  int embeddings_num_cols(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).cols();
  }

  const void *embeddings_weights(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (proto_->embeddings(i).is_quantized()) {
      return static_cast<const void *>(embeddings_quant_weights_.at(i).data());
    } else {
      return static_cast<const void *>(proto_->embeddings(i).value().data());
    }
  }

  QuantizationType embeddings_quant_type(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
                                                : QuantizationType::NONE;
  }

  const float16 *embeddings_quant_scales(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized()
               ? embeddings_quant_scales_.at(i).data()
               : nullptr;
  }
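
  // Sketch of how a caller might recover a float weight from a quantized
  // embedding via the accessors above (illustrative comment only; the scale
  // application shown is an assumption about the consumer, not something
  // this class defines -- the authoritative dequantization lives in the
  // network code; Float16To32 is assumed from common/float16.h):
  //
  //   if (embeddings_quant_type(i) == QuantizationType::UINT8) {
  //     const uint8 *w = static_cast<const uint8 *>(embeddings_weights(i));
  //     const float16 *scales = embeddings_quant_scales(i);
  //     // One float16 scale per row of the transposed matrix; assumed
  //     // reconstruction of the (row, col) element:
  //     float value = Float16To32(scales[row]) *
  //                   w[row * embeddings_num_cols(i) + col];
  //   }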

  int hidden_size() const override { return proto_->hidden_size(); }

  int hidden_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).rows();
  }

  int hidden_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).cols();
  }

  const void *hidden_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).value().data();
  }

  int hidden_bias_size() const override { return proto_->hidden_bias_size(); }

  int hidden_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).rows();
  }

  int hidden_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).cols();
  }

  const void *hidden_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).value().data();
  }

  int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; }

  int softmax_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().rows() : 0;
  }

  int softmax_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().cols() : 0;
  }

  const void *softmax_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
  }

  int softmax_bias_size() const override {
    return proto_->has_softmax_bias() ? 1 : 0;
  }

  int softmax_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
  }

  int softmax_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
  }

  const void *softmax_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
                                      : nullptr;
  }

  int embedding_num_features_size() const override {
    return proto_->embedding_num_features_size();
  }

  int embedding_num_features(int i) const override {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return proto_->embedding_num_features(i);
  }
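
  // Note on the generic read pattern for the matrix accessors above
  // (illustrative comment only; a hypothetical caller, with the row-major
  // layout being an assumption about MatrixParams.value):
  //
  //   for (int i = 0; i < hidden_size(); ++i) {
  //     const float *w = static_cast<const float *>(hidden_weights(i));
  //     // w is assumed to point at hidden_num_rows(i) * hidden_num_cols(i)
  //     // floats, stored row-major in the transposed orientation described
  //     // in the class comment. softmax_size() is 0 or 1, so the same loop
  //     // over the softmax accessors touches at most one matrix.
  //   }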

 private:
  std::unique_ptr<EmbeddingNetworkProto> proto_;

  // True if these params are valid. May be false if the original proto was
  // corrupted. We prefer setting this to false over CHECK-failing.
  bool valid_;

  // When the embeddings are quantized, these members are used to store their
  // numeric values using the types expected by the rest of the class. For
  // technical reasons, the proto stores this info using larger types (i.e.,
  // more bits).
  std::vector<std::vector<float16>> embeddings_quant_scales_;
  std::vector<std::vector<uint8>> embeddings_quant_weights_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_