1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_ 18 #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_ 19 20 #include <map> 21 #include <vector> 22 #include "schema/model_generated.h" 23 #include "nnacl/op_base.h" 24 #include "src/common/common.h" 25 #include "src/common/log_adapter.h" 26 #include "src/common/prim_util.h" 27 #include "src/common/version_manager.h" 28 #include "src/common/utils.h" 29 #include "src/common/log_util.h" 30 31 namespace mindspore { 32 namespace lite { 33 constexpr int kOffsetTwo = 2; 34 constexpr int kOffsetThree = 3; 35 constexpr size_t kMinShapeSizeTwo = 2; 36 constexpr size_t kMinShapeSizeFour = 4; 37 typedef OpParameter *(*ParameterGen)(const void *prim); 38 39 static const std::vector<schema::PrimitiveType> string_op = { 40 schema::PrimitiveType_CustomExtractFeatures, schema::PrimitiveType_CustomNormalize, 41 schema::PrimitiveType_CustomPredict, schema::PrimitiveType_HashtableLookup, 42 schema::PrimitiveType_LshProjection, schema::PrimitiveType_SkipGram}; 43 44 class PopulateRegistry { 45 public: 46 static PopulateRegistry *GetInstance(); 47 InsertParameterMap(int type,ParameterGen creator,int version)48 void InsertParameterMap(int type, ParameterGen creator, int version) { 49 parameters_[GenPrimVersionKey(type, version)] = creator; 50 } 51 GetParameterCreator(int type,int version)52 ParameterGen GetParameterCreator(int type, int version) { 53 ParameterGen param_creator = nullptr; 54 auto iter = parameters_.find(GenPrimVersionKey(type, version)); 55 if (iter == parameters_.end()) { 56 #ifdef STRING_KERNEL_CLIP 57 if (lite::IsContain(string_op, static_cast<schema::PrimitiveType>(type))) { 58 MS_LOG(ERROR) << unsupport_string_tensor_log; 59 return nullptr; 60 } 61 #endif 62 MS_LOG(ERROR) << "Unsupported parameter type in Create : " 63 << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type)); 64 return nullptr; 65 } 66 param_creator = iter->second; 67 return param_creator; 68 } 69 70 protected: 71 // key:type * 1000 + schema_version 72 std::map<int, ParameterGen> parameters_; 73 }; 74 75 class Registry { 76 public: Registry(int primitive_type,ParameterGen creator,int version)77 Registry(int primitive_type, ParameterGen creator, int version) noexcept { 78 PopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator, version); 79 } 80 ~Registry() = default; 81 }; 82 83 #define REG_POPULATE(primitive_type, creator, version) \ 84 static Registry g_##primitive_type##version(primitive_type, creator, version); 85 86 } // namespace lite 87 } // namespace mindspore 88 #endif // MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_ 89