• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_
18 #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_
19 
20 #include <map>
21 #include <vector>
22 #include <string>
23 
24 #include "schema/model_generated.h"
25 #include "nnacl/op_base.h"
26 #include "src/common/common.h"
27 #include "src/common/log_adapter.h"
28 #include "src/common/prim_util.h"
29 #include "src/common/version_manager.h"
30 #include "src/common/utils.h"
31 #include "src/common/log_util.h"
32 
33 namespace mindspore {
// Named numeric constants shared by code that includes this header
// (their users are not visible in this file).
constexpr int kOffsetTwo = 2;
constexpr int kOffsetThree = 3;
constexpr size_t kMinShapeSizeTwo = 2;
constexpr size_t kMinShapeSizeFour = 4;
38 typedef OpParameter *(*BaseOperator2Parameter)(void *base_operator);
39 
40 static const std::vector<schema::PrimitiveType> string_op = {
41   schema::PrimitiveType_CustomExtractFeatures, schema::PrimitiveType_CustomNormalize,
42   schema::PrimitiveType_CustomPredict,         schema::PrimitiveType_HashtableLookup,
43   schema::PrimitiveType_LshProjection,         schema::PrimitiveType_SkipGram};
44 
45 class BaseOperatorPopulateRegistry {
46  public:
47   static BaseOperatorPopulateRegistry *GetInstance();
48 
49   void InsertParameterMap(int type, BaseOperator2Parameter creator, int version = lite::SCHEMA_CUR) {
50     parameters_[lite::GenPrimVersionKey(type, version)] = creator;
51     std::string str = schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
52     str_to_type_map_[str] = type;
53   }
54 
55   BaseOperator2Parameter GetParameterCreator(int type, int version = lite::SCHEMA_CUR) {
56     BaseOperator2Parameter param_creator = nullptr;
57     auto iter = parameters_.find(lite::GenPrimVersionKey(type, version));
58     if (iter == parameters_.end()) {
59 #ifdef STRING_KERNEL_CLIP
60       if (lite::IsContain(string_op, static_cast<schema::PrimitiveType>(type))) {
61         MS_LOG(ERROR) << unsupport_string_tensor_log;
62         return nullptr;
63       }
64 #endif
65       MS_LOG(ERROR) << "Unsupported parameter type in Create : "
66                     << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
67       return nullptr;
68     }
69     param_creator = iter->second;
70     return param_creator;
71   }
72 
TypeStrToType(const std::string & type_str)73   int TypeStrToType(const std::string &type_str) {
74     auto iter = str_to_type_map_.find(type_str);
75     if (iter == str_to_type_map_.end()) {
76       MS_LOG(ERROR) << "Unknown type string to type " << type_str;
77       return schema::PrimitiveType_NONE;
78     }
79     return iter->second;
80   }
81 
82  protected:
83   // key:type * 1000 + schema_version
84   std::map<int, BaseOperator2Parameter> parameters_;
85   std::map<std::string, int> str_to_type_map_;
86 };
87 
88 class BaseRegistry {
89  public:
90   BaseRegistry(int primitive_type, BaseOperator2Parameter creator, int version = lite::SCHEMA_CUR) noexcept {
91     BaseOperatorPopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator, version);
92   }
93   ~BaseRegistry() = default;
94 };
95 
// Defines a file-static BaseRegistry named g_<primitive_type>base_populate whose
// constructor registers `creator` for `primitive_type` with the default schema version.
// `primitive_type` is token-pasted into the variable name, so it must be a single identifier.
// The expansion already ends with ';'.
#define REG_BASE_POPULATE(primitive_type, creator) \
  static BaseRegistry g_##primitive_type##base_populate(primitive_type, creator);
98 }  // namespace mindspore
99 #endif  // MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_
100