• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_
18 #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_
19 
20 #include <map>
21 #include <vector>
22 #include "schema/model_generated.h"
23 #include "nnacl/op_base.h"
24 #include "src/common/common.h"
25 #include "src/common/log_adapter.h"
26 #include "src/common/prim_util.h"
27 #include "src/common/version_manager.h"
28 #include "src/common/utils.h"
29 #include "src/common/log_util.h"
30 
31 namespace mindspore {
32 namespace lite {
// Small named constants made available to code that includes this header.
// Their call sites are not in this file; the names suggest index offsets and
// minimum shape sizes -- confirm against the populate sources before reuse.
constexpr int kOffsetTwo{2};
constexpr int kOffsetThree{3};
constexpr size_t kMinShapeSizeTwo{2};
constexpr size_t kMinShapeSizeFour{4};
37 typedef OpParameter *(*ParameterGen)(const void *prim);
38 
39 static const std::vector<schema::PrimitiveType> string_op = {
40   schema::PrimitiveType_CustomExtractFeatures, schema::PrimitiveType_CustomNormalize,
41   schema::PrimitiveType_CustomPredict,         schema::PrimitiveType_HashtableLookup,
42   schema::PrimitiveType_LshProjection,         schema::PrimitiveType_SkipGram};
43 
44 class PopulateRegistry {
45  public:
46   static PopulateRegistry *GetInstance();
47 
InsertParameterMap(int type,ParameterGen creator,int version)48   void InsertParameterMap(int type, ParameterGen creator, int version) {
49     parameters_[GenPrimVersionKey(type, version)] = creator;
50   }
51 
GetParameterCreator(int type,int version)52   ParameterGen GetParameterCreator(int type, int version) {
53     ParameterGen param_creator = nullptr;
54     auto iter = parameters_.find(GenPrimVersionKey(type, version));
55     if (iter == parameters_.end()) {
56 #ifdef STRING_KERNEL_CLIP
57       if (lite::IsContain(string_op, static_cast<schema::PrimitiveType>(type))) {
58         MS_LOG(ERROR) << unsupport_string_tensor_log;
59         return nullptr;
60       }
61 #endif
62       MS_LOG(ERROR) << "Unsupported parameter type in Create : "
63                     << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
64       return nullptr;
65     }
66     param_creator = iter->second;
67     return param_creator;
68   }
69 
70  protected:
71   // key:type * 1000 + schema_version
72   std::map<int, ParameterGen> parameters_;
73 };
74 
75 class Registry {
76  public:
Registry(int primitive_type,ParameterGen creator,int version)77   Registry(int primitive_type, ParameterGen creator, int version) noexcept {
78     PopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator, version);
79   }
80   ~Registry() = default;
81 };
82 
// Defines a static Registry object named g_<prim><schema_ver>; constructing it
// passes (prim, generator, schema_ver) to the Registry constructor.
// The expansion already ends with a semicolon.
#define REG_POPULATE(prim, generator, schema_ver) \
  static Registry g_##prim##schema_ver(prim, generator, schema_ver);
85 
86 }  // namespace lite
87 }  // namespace mindspore
88 #endif  // MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_POPULATE_REGISTER_H_
89