1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Model parameter loading. 18 19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_ 20 #define LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_ 21 22 #include "common/embedding-network.h" 23 #include "common/memory_image/embedding-network-params-from-image.h" 24 #include "smartselect/text-classification-model.pb.h" 25 26 namespace libtextclassifier { 27 28 class EmbeddingParams : public nlp_core::EmbeddingNetworkParamsFromImage { 29 public: EmbeddingParams(const void * start,uint64 num_bytes,int context_size)30 EmbeddingParams(const void* start, uint64 num_bytes, int context_size) 31 : EmbeddingNetworkParamsFromImage(start, num_bytes), 32 context_size_(context_size) {} 33 embeddings_size()34 int embeddings_size() const override { return context_size_ * 2 + 1; } 35 embedding_num_features_size()36 int embedding_num_features_size() const override { 37 return context_size_ * 2 + 1; 38 } 39 embedding_num_features(int i)40 int embedding_num_features(int i) const override { return 1; } 41 embeddings_num_rows(int i)42 int embeddings_num_rows(int i) const override { 43 return EmbeddingNetworkParamsFromImage::embeddings_num_rows(0); 44 }; 45 embeddings_num_cols(int i)46 int embeddings_num_cols(int i) const override { 47 return EmbeddingNetworkParamsFromImage::embeddings_num_cols(0); 48 }; 49 embeddings_weights(int i)50 const void* embeddings_weights(int i) const override { 51 return EmbeddingNetworkParamsFromImage::embeddings_weights(0); 52 }; 53 embeddings_quant_type(int i)54 nlp_core::QuantizationType embeddings_quant_type(int i) const override { 55 return EmbeddingNetworkParamsFromImage::embeddings_quant_type(0); 56 } 57 embeddings_quant_scales(int i)58 const nlp_core::float16* embeddings_quant_scales(int i) const override { 59 return EmbeddingNetworkParamsFromImage::embeddings_quant_scales(0); 60 } 61 62 private: 63 int context_size_; 64 }; 65 66 // Loads and holds the parameters of the inference network. 67 // 68 // This class overrides a couple of methods of EmbeddingNetworkParamsFromImage 69 // because we only have one embedding matrix for all positions of context, 70 // whereas the original class would have a separate one for each. 71 class ModelParams : public nlp_core::EmbeddingNetworkParamsFromImage { 72 public: GetFeatureProcessorOptions()73 const FeatureProcessorOptions& GetFeatureProcessorOptions() const { 74 return feature_processor_options_; 75 } 76 GetSelectionModelOptions()77 const SelectionModelOptions& GetSelectionModelOptions() const { 78 return selection_options_; 79 } 80 GetSharingModelOptions()81 const SharingModelOptions& GetSharingModelOptions() const { 82 return sharing_options_; 83 } 84 GetEmbeddingParams()85 std::shared_ptr<EmbeddingParams> GetEmbeddingParams() const { 86 return embedding_params_; 87 } 88 89 protected: embeddings_size()90 int embeddings_size() const override { 91 return embedding_params_->embeddings_size(); 92 } 93 embedding_num_features_size()94 int embedding_num_features_size() const override { 95 return embedding_params_->embedding_num_features_size(); 96 } 97 embedding_num_features(int i)98 int embedding_num_features(int i) const override { 99 return embedding_params_->embedding_num_features(i); 100 } 101 embeddings_num_rows(int i)102 int embeddings_num_rows(int i) const override { 103 return embedding_params_->embeddings_num_rows(i); 104 }; 105 embeddings_num_cols(int i)106 int embeddings_num_cols(int i) const override { 107 return embedding_params_->embeddings_num_cols(i); 108 }; 109 embeddings_weights(int i)110 const void* embeddings_weights(int i) const override { 111 return embedding_params_->embeddings_weights(i); 112 }; 113 embeddings_quant_type(int i)114 nlp_core::QuantizationType embeddings_quant_type(int i) const override { 115 return embedding_params_->embeddings_quant_type(i); 116 } 117 embeddings_quant_scales(int i)118 const nlp_core::float16* embeddings_quant_scales(int i) const override { 119 return embedding_params_->embeddings_quant_scales(i); 120 } 121 122 private: 123 friend ModelParams* ModelParamsBuilder( 124 const void* start, uint64 num_bytes, 125 std::shared_ptr<EmbeddingParams> external_embedding_params); 126 ModelParams(const void * start,uint64 num_bytes,std::shared_ptr<EmbeddingParams> embedding_params,const SelectionModelOptions & selection_options,const SharingModelOptions & sharing_options,const FeatureProcessorOptions & feature_processor_options)127 ModelParams(const void* start, uint64 num_bytes, 128 std::shared_ptr<EmbeddingParams> embedding_params, 129 const SelectionModelOptions& selection_options, 130 const SharingModelOptions& sharing_options, 131 const FeatureProcessorOptions& feature_processor_options) 132 : EmbeddingNetworkParamsFromImage(start, num_bytes), 133 selection_options_(selection_options), 134 sharing_options_(sharing_options), 135 feature_processor_options_(feature_processor_options), 136 context_size_(feature_processor_options_.context_size()), 137 embedding_params_(std::move(embedding_params)) {} 138 139 SelectionModelOptions selection_options_; 140 SharingModelOptions sharing_options_; 141 FeatureProcessorOptions feature_processor_options_; 142 int context_size_; 143 std::shared_ptr<EmbeddingParams> embedding_params_; 144 }; 145 146 ModelParams* ModelParamsBuilder( 147 const void* start, uint64 num_bytes, 148 std::shared_ptr<EmbeddingParams> external_embedding_params); 149 150 } // namespace libtextclassifier 151 152 #endif // LIBTEXTCLASSIFIER_SMARTSELECT_MODEL_PARAMS_H_ 153