/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/float16.h"
#include "common/little-endian-data.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {

// A wrapper class that owns and exposes an EmbeddingNetworkProto message via
// the EmbeddingNetworkParams interface.
//
// The EmbeddingNetworkParams interface encapsulates the weight matrices of
// the embeddings, hidden, and softmax layers as transposed versions of their
// counterparts in the original EmbeddingNetworkProto.  The matrices in the
// proto passed to this class' constructor must likewise already have been
// transposed.  See embedding-network-params.h for details.
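//
// Example usage (a minimal sketch; how the serialized proto is obtained is up
// to the caller, and ReadFileToString() below is a hypothetical helper, not
// part of this library):
//
//   std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto);
//   if (!proto->ParseFromString(ReadFileToString(model_path))) {
//     // Handle the parse error.
//   }
//   EmbeddingNetworkParamsFromProto params(std::move(proto));
//   if (!params.is_valid()) {
//     // Handle corrupted model data.
//   }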
class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
 public:
  // Constructor that takes ownership of the provided proto.  See the class
  // comment for the requirements that certain weight matrices must satisfy.
  explicit EmbeddingNetworkParamsFromProto(
      std::unique_ptr<EmbeddingNetworkProto> proto)
      : proto_(std::move(proto)) {
    valid_ = true;

    // Initialize these vectors to have the required number of elements
    // regardless of quantization status.  This supports the unlikely case
    // where only some embeddings are quantized, along with the fact that the
    // EmbeddingNetworkParams interface accesses them by index.
    embeddings_quant_scales_.resize(proto_->embeddings_size());
    embeddings_quant_weights_.resize(proto_->embeddings_size());
    for (int i = 0; i < proto_->embeddings_size(); ++i) {
      MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
      if (!embedding->is_quantized()) {
        continue;
      }

      bool success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_quantized_values(),
          embedding->rows() * embedding->cols(),
          &(embeddings_quant_weights_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #"
                      << i;
        valid_ = false;
      }

      // The repeated field bytes_for_quantized_values uses a lot of memory.
      // Since it's no longer necessary (and we own the proto), we clear it.
      embedding->clear_bytes_for_quantized_values();

      success = FillVectorFromDataBytesInLittleEndian(
          embedding->bytes_for_col_scales(),
          embedding->rows(),
          &(embeddings_quant_scales_[i]));
      if (!success) {
        TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
        valid_ = false;
      }

      // See comments for clear_bytes_for_quantized_values().
      embedding->clear_bytes_for_col_scales();
    }
  }

  const TaskSpec *GetTaskSpec() override {
    if (!proto_) {
      return nullptr;
    }
    auto extension_id = task_spec_in_embedding_network_proto;
    if (proto_->HasExtension(extension_id)) {
      return &(proto_->GetExtension(extension_id));
    } else {
      TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
      return nullptr;
    }
  }

  // Returns true if these params are valid.  False otherwise (e.g., if the
  // original proto data was corrupted).
  bool is_valid() { return valid_; }

 protected:
  int embeddings_size() const override { return proto_->embeddings_size(); }

  int embeddings_num_rows(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).rows();
  }

  int embeddings_num_cols(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).cols();
  }

  const void *embeddings_weights(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (proto_->embeddings(i).is_quantized()) {
      return static_cast<const void *>(
          embeddings_quant_weights_.at(i).data());
    } else {
      return static_cast<const void *>(proto_->embeddings(i).value().data());
    }
  }

  QuantizationType embeddings_quant_type(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
                                                : QuantizationType::NONE;
  }

  const float16 *embeddings_quant_scales(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return proto_->embeddings(i).is_quantized()
               ? embeddings_quant_scales_.at(i).data()
               : nullptr;
  }
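
  // Note on consuming the quantized form (an illustrative sketch, not an API
  // of this class): for a quantized embedding matrix i, the decode in the
  // constructor produces rows() * cols() uint8 weights and rows() float16
  // scales, i.e. one scale per row of the (transposed) matrix stored here.
  // Assuming the usual "value = weight * scale" convention and a float16 to
  // float conversion helper (called Float16To32() below; the name is an
  // assumption), element (r, c) would be recovered roughly as:
  //
  //   const uint8 *weights =
  //       static_cast<const uint8 *>(embeddings_weights(i));
  //   const float16 *scales = embeddings_quant_scales(i);
  //   float value = weights[r * embeddings_num_cols(i) + c] *
  //                 Float16To32(scales[r]);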

  int hidden_size() const override { return proto_->hidden_size(); }

  int hidden_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).rows();
  }

  int hidden_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).cols();
  }

  const void *hidden_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return proto_->hidden(i).value().data();
  }

  int hidden_bias_size() const override { return proto_->hidden_bias_size(); }

  int hidden_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).rows();
  }

  int hidden_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).cols();
  }

  const void *hidden_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return proto_->hidden_bias(i).value().data();
  }

  int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; }

  int softmax_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().rows() : 0;
  }

  int softmax_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().cols() : 0;
  }

  const void *softmax_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
  }

  int softmax_bias_size() const override {
    return proto_->has_softmax_bias() ? 1 : 0;
  }

  int softmax_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
  }

  int softmax_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
  }

  const void *softmax_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
                                      : nullptr;
  }

  int embedding_num_features_size() const override {
    return proto_->embedding_num_features_size();
  }

  int embedding_num_features(int i) const override {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return proto_->embedding_num_features(i);
  }

 private:
  std::unique_ptr<EmbeddingNetworkProto> proto_;

  // True if these params are valid.  May be false if the original proto was
  // corrupted.  We prefer setting this to false over CHECK-failing.
  bool valid_;

  // When the embeddings are quantized, these members are used to store their
  // numeric values using the types expected by the rest of the class.  For
  // technical reasons, the proto stores this info using larger types (i.e.,
  // more bits).
  std::vector<std::vector<float16>> embeddings_quant_scales_;
  std::vector<std::vector<uint8>> embeddings_quant_weights_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_