1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_ 18 #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_ 19 20 #include <map> 21 #include <vector> 22 #include <string> 23 24 #include "schema/model_generated.h" 25 #include "nnacl/op_base.h" 26 #include "src/common/common.h" 27 #include "src/common/log_adapter.h" 28 #include "src/common/prim_util.h" 29 #include "src/common/version_manager.h" 30 #include "src/common/utils.h" 31 #include "src/common/log_util.h" 32 33 namespace mindspore { 34 constexpr int kOffsetTwo = 2; 35 constexpr int kOffsetThree = 3; 36 constexpr size_t kMinShapeSizeTwo = 2; 37 constexpr size_t kMinShapeSizeFour = 4; 38 typedef OpParameter *(*BaseOperator2Parameter)(void *base_operator); 39 40 static const std::vector<schema::PrimitiveType> string_op = { 41 schema::PrimitiveType_CustomExtractFeatures, schema::PrimitiveType_CustomNormalize, 42 schema::PrimitiveType_CustomPredict, schema::PrimitiveType_HashtableLookup, 43 schema::PrimitiveType_LshProjection, schema::PrimitiveType_SkipGram}; 44 45 class BaseOperatorPopulateRegistry { 46 public: 47 static BaseOperatorPopulateRegistry *GetInstance(); 48 49 void InsertParameterMap(int type, BaseOperator2Parameter creator, int version = lite::SCHEMA_CUR) { 50 parameters_[lite::GenPrimVersionKey(type, version)] = creator; 51 std::string str = schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type)); 52 str_to_type_map_[str] = type; 53 } 54 55 BaseOperator2Parameter GetParameterCreator(int type, int version = lite::SCHEMA_CUR) { 56 BaseOperator2Parameter param_creator = nullptr; 57 auto iter = parameters_.find(lite::GenPrimVersionKey(type, version)); 58 if (iter == parameters_.end()) { 59 #ifdef STRING_KERNEL_CLIP 60 if (lite::IsContain(string_op, static_cast<schema::PrimitiveType>(type))) { 61 MS_LOG(ERROR) << unsupport_string_tensor_log; 62 return nullptr; 63 } 64 #endif 65 MS_LOG(ERROR) << "Unsupported parameter type in Create : " 66 << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type)); 67 return nullptr; 68 } 69 param_creator = iter->second; 70 return param_creator; 71 } 72 TypeStrToType(const std::string & type_str)73 int TypeStrToType(const std::string &type_str) { 74 auto iter = str_to_type_map_.find(type_str); 75 if (iter == str_to_type_map_.end()) { 76 MS_LOG(ERROR) << "Unknown type string to type " << type_str; 77 return schema::PrimitiveType_NONE; 78 } 79 return iter->second; 80 } 81 82 protected: 83 // key:type * 1000 + schema_version 84 std::map<int, BaseOperator2Parameter> parameters_; 85 std::map<std::string, int> str_to_type_map_; 86 }; 87 88 class BaseRegistry { 89 public: 90 BaseRegistry(int primitive_type, BaseOperator2Parameter creator, int version = lite::SCHEMA_CUR) noexcept { 91 BaseOperatorPopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator, version); 92 } 93 ~BaseRegistry() = default; 94 }; 95 96 #define REG_BASE_POPULATE(primitive_type, creator) \ 97 static BaseRegistry g_##primitive_type##base_populate(primitive_type, creator); 98 } // namespace mindspore 99 #endif // MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_ 100