1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/common/prim_util.h"
18 #include <set>
19 #include <vector>
20 #include "nnacl/op_base.h"
21 #include "src/common/log_util.h"
22 #include "schema/model_generated.h"
23 #include "src/common/log_adapter.h"
24
25 namespace mindspore {
26 namespace lite {
27 static std::set<schema::PrimitiveType> kTensorListOps = {
28 schema::PrimitiveType_TensorListFromTensor, schema::PrimitiveType_TensorListGetItem,
29 schema::PrimitiveType_TensorListReserve, schema::PrimitiveType_TensorListSetItem,
30 schema::PrimitiveType_TensorListStack};
31 static const char *const kInnerOpNames[C20NUM] = {"Inner_ToFormat", "Inner_GltextureToOpencl",
32 "Inner_Identity", "Inner_ShapeFusion",
33 "Inner_GraphKernel", "Inner_SplitReduceConcatFusion",
34 "Inner_EncoderLayer", "Inner_FseDecode",
35 "Inner_DecoderLayer", "Inner_UsePastEmbedding",
36 "Inner_CustomGru", "Inner_CastGatherReduceFusion",
37 "Inner_ReduceConcatFusion", "Inner_AclCustomOp",
38 "Inner_CustomMaskedFill", "Inner_CustomTensorScatterMax",
39 "Inner_CustomIsInf", "Inner_CustomGatherDGradV2",
40 "Inner_ThirdPartyModel"};
GetPrimitiveType(const void * primitive,int schema_version)41 int GetPrimitiveType(const void *primitive, int schema_version) {
42 if (primitive == nullptr) {
43 return -1;
44 }
45 return static_cast<const schema::Primitive *>(primitive)->value_type();
46 }
47
GetPrimitiveTypeName(const void * primitive,int schema_version)48 const char *GetPrimitiveTypeName(const void *primitive, int schema_version) {
49 if (primitive == nullptr) {
50 return "NONE";
51 }
52 return schema::EnumNamePrimitiveType(static_cast<const schema::Primitive *>(primitive)->value_type());
53 }
54
PrimitiveCurVersionTypeName(int type)55 const char *PrimitiveCurVersionTypeName(int type) {
56 if (type >= static_cast<int>(schema::PrimitiveType_MIN) && type < static_cast<int>(schema::PrimitiveType_MAX)) {
57 return schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(type));
58 } else if (type >= static_cast<int>(schema::PrimitiveType_MAX)) {
59 if (type >= PrimType_InnerOpMin && type < PrimType_InnerOpMax) {
60 return kInnerOpNames[type - PrimType_InnerOpMin];
61 }
62 }
63 return "";
64 }
65
// Combines a primitive type and a schema version into one integer key:
// the type occupies the thousands and above, the version is added on top.
int GenPrimVersionKey(int primitive_type, int schema_version) {
  const int kVersionSlot = 1000;
  const int type_part = primitive_type * kVersionSlot;
  return type_part + schema_version;
}
67
IsPartialNode(const void * primitive,int schema_version)68 bool IsPartialNode(const void *primitive, int schema_version) {
69 MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
70 if (schema_version == SCHEMA_CUR) {
71 return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion;
72 }
73 return false;
74 }
75
IsCallNode(const void * primitive,int schema_version)76 bool IsCallNode(const void *primitive, int schema_version) {
77 MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
78 if (schema_version == SCHEMA_CUR) {
79 return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Call;
80 }
81 return false;
82 }
83
IsSwitchNode(const void * primitive,int schema_version)84 bool IsSwitchNode(const void *primitive, int schema_version) {
85 MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
86 if (schema_version == SCHEMA_CUR) {
87 return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Switch;
88 }
89 return false;
90 }
91
IsSwitchLayerNode(const void * primitive,int schema_version)92 bool IsSwitchLayerNode(const void *primitive, int schema_version) {
93 MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
94 if (schema_version == SCHEMA_CUR) {
95 return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_SwitchLayer;
96 }
97 return false;
98 }
99
IsCustomNode(const void * primitive,int schema_version)100 bool IsCustomNode(const void *primitive, int schema_version) {
101 MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
102 if (schema_version == SCHEMA_CUR) {
103 return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_Custom;
104 }
105 return false;
106 }
107
IsTensorListNode(const void * primitive,int schema_version)108 bool IsTensorListNode(const void *primitive, int schema_version) {
109 MS_CHECK_TRUE_MSG(primitive != nullptr, false, "primtive cannot be nullptr");
110 if (schema_version == SCHEMA_CUR) {
111 if (kTensorListOps.find(reinterpret_cast<const schema::Primitive *>(primitive)->value_type()) !=
112 kTensorListOps.end()) {
113 return true;
114 }
115 }
116 return false;
117 }
118
GetPartialGraphIndex(const void * primitive,int schema_version)119 int GetPartialGraphIndex(const void *primitive, int schema_version) {
120 MS_CHECK_TRUE_MSG(primitive != nullptr, -1, "primtive cannot be nullptr");
121 int index = -1;
122 if (schema_version == SCHEMA_CUR) {
123 auto partial_fusion = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion();
124 if (partial_fusion == nullptr) {
125 return -1;
126 }
127 index = partial_fusion->sub_graph_index();
128 }
129 return index;
130 }
IsSharedThreadPoolOp(int op_type)131 bool IsSharedThreadPoolOp(int op_type) {
132 std::vector<schema::PrimitiveType> shared_ops = {mindspore::schema::PrimitiveType_MatMulFusion};
133 if (find(shared_ops.begin(), shared_ops.end(), op_type) != shared_ops.end()) {
134 return true;
135 }
136 return false;
137 }
138 } // namespace lite
139 } // namespace mindspore
140