• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "src/common/prim_util.h"
#include <algorithm>
#include <set>
#include <vector>
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
#include "schema/model_generated.h"
#include "src/common/log_adapter.h"
24 
25 namespace mindspore {
26 namespace lite {
27 static std::set<schema::PrimitiveType> kTensorListOps = {
28   schema::PrimitiveType_TensorListFromTensor, schema::PrimitiveType_TensorListGetItem,
29   schema::PrimitiveType_TensorListReserve, schema::PrimitiveType_TensorListSetItem,
30   schema::PrimitiveType_TensorListStack};
// Display names for the inner (non-schema) operator types, indexed by
// (type - PrimType_InnerOpMin); used by PrimitiveCurVersionTypeName.
// The array is sized C20NUM, so any slots beyond the listed initializers are
// zero-initialized to nullptr per aggregate-initialization rules.
static const char *const kInnerOpNames[C20NUM] = {"Inner_ToFormat",           "Inner_GltextureToOpencl",
                                                  "Inner_Identity",           "Inner_ShapeFusion",
                                                  "Inner_GraphKernel",        "Inner_SplitReduceConcatFusion",
                                                  "Inner_EncoderLayer",       "Inner_FseDecode",
                                                  "Inner_DecoderLayer",       "Inner_UsePastEmbedding",
                                                  "Inner_CustomGru",          "Inner_CastGatherReduceFusion",
                                                  "Inner_ReduceConcatFusion", "Inner_AclCustomOp",
                                                  "Inner_CustomMaskedFill",   "Inner_CustomTensorScatterMax",
                                                  "Inner_CustomIsInf",        "Inner_CustomGatherDGradV2",
                                                  "Inner_ThirdPartyModel"};
GetPrimitiveType(const void * primitive,int schema_version)41 int GetPrimitiveType(const void *primitive, int schema_version) {
42   if (primitive == nullptr) {
43     return -1;
44   }
45   return static_cast<const schema::Primitive *>(primitive)->value_type();
46 }
47 
GetPrimitiveTypeName(const void * primitive,int schema_version)48 const char *GetPrimitiveTypeName(const void *primitive, int schema_version) {
49   if (primitive == nullptr) {
50     return "NONE";
51   }
52   return schema::EnumNamePrimitiveType(static_cast<const schema::Primitive *>(primitive)->value_type());
53 }
54 
PrimitiveCurVersionTypeName(int type)55 const char *PrimitiveCurVersionTypeName(int type) {
56   if (type >= static_cast<int>(schema::PrimitiveType_MIN) && type < static_cast<int>(schema::PrimitiveType_MAX)) {
57     return schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
58   } else if (type >= static_cast<int>(schema::PrimitiveType_MAX)) {
59     if (type >= PrimType_InnerOpMin && type < PrimType_InnerOpMax) {
60       return kInnerOpNames[type - PrimType_InnerOpMin];
61     }
62   }
63   return "";
64 }
65 
GenPrimVersionKey(int primitive_type,int schema_version)66 int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; }
67 
IsPartialNode(const void * primitive,int schema_version)68 bool IsPartialNode(const void *primitive, int schema_version) {
69   MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
70   if (schema_version == SCHEMA_CUR) {
71     return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion;
72   }
73   return false;
74 }
75 
IsCallNode(const void * primitive,int schema_version)76 bool IsCallNode(const void *primitive, int schema_version) {
77   MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
78   if (schema_version == SCHEMA_CUR) {
79     return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Call;
80   }
81   return false;
82 }
83 
IsSwitchNode(const void * primitive,int schema_version)84 bool IsSwitchNode(const void *primitive, int schema_version) {
85   MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
86   if (schema_version == SCHEMA_CUR) {
87     return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Switch;
88   }
89   return false;
90 }
91 
IsSwitchLayerNode(const void * primitive,int schema_version)92 bool IsSwitchLayerNode(const void *primitive, int schema_version) {
93   MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
94   if (schema_version == SCHEMA_CUR) {
95     return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_SwitchLayer;
96   }
97   return false;
98 }
99 
IsCustomNode(const void * primitive,int schema_version)100 bool IsCustomNode(const void *primitive, int schema_version) {
101   MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
102   if (schema_version == SCHEMA_CUR) {
103     return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Custom;
104   }
105   return false;
106 }
107 
IsTensorListNode(const void * primitive,int schema_version)108 bool IsTensorListNode(const void *primitive, int schema_version) {
109   MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
110   if (schema_version == SCHEMA_CUR) {
111     if (kTensorListOps.find(reinterpret_cast<const schema::Primitive *>(primitive)->value_type()) !=
112         kTensorListOps.end()) {
113       return true;
114     }
115   }
116   return false;
117 }
118 
GetPartialGraphIndex(const void * primitive,int schema_version)119 int GetPartialGraphIndex(const void *primitive, int schema_version) {
120   MS_CHECK_TRUE_MSG(primitive != nullptr, -1, "primtive cannot be nullptr");
121   int index = -1;
122   if (schema_version == SCHEMA_CUR) {
123     auto partial_fusion = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion();
124     if (partial_fusion == nullptr) {
125       return -1;
126     }
127     index = partial_fusion->sub_graph_index();
128   }
129   return index;
130 }
IsSharedThreadPoolOp(int op_type)131 bool IsSharedThreadPoolOp(int op_type) {
132   std::vector<schema::PrimitiveType> shared_ops = {mindspore::schema::PrimitiveType_MatMulFusion};
133   if (find(shared_ops.begin(), shared_ops.end(), op_type) != shared_ops.end()) {
134     return true;
135   }
136   return false;
137 }
138 }  // namespace lite
139 }  // namespace mindspore
140