1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
18 #define MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
19
20 #include <memory>
21 #include <vector>
22 #include <map>
23 #include <string>
24 #include <nlohmann/json.hpp>
25 #include "base/base.h"
26 #include "base/base_ref.h"
27 #include "c_api/src/resource_manager.h"
28 #include "include/c_api/ms/context.h"
29 #include "include/c_api/ms/node.h"
30 #include "c_api/src/common.h"
31
32 const std::map<DTypeFormat, std::vector<std::string>> kDTypeFmtEnumToStrMap = {
33 {None_None, {"", ""}},
34 {None_Default, {"", "DefaultFormat"}},
35 {BOOL_None, {"bool", ""}},
36 {BOOL_Default, {"bool", "DefaultFormat"}},
37 {BOOL_5HD, {"bool", "NC1HWC0"}},
38 {BOOL_FracZ, {"bool", "FRACTAL_Z"}},
39 {BOOL_FracNZ, {"bool", "FRACTAL_NZ"}},
40 {BOOL_C1HWNCoC0, {"bool", "C1HWNCoC0"}},
41 {BOOL_NCHW, {"bool", "NCHW"}},
42 {BOOL_NHWC, {"bool", "NHWC"}},
43 {BOOL_HWCN, {"bool", "HWCN"}},
44 {BOOL_NDHWC, {"bool", "NDHWC"}},
45 {BOOL_ChannelLast, {"bool", "ChannelLast"}},
46 {BOOL_Default_Tuple, {"bool", "DefaultFormat", "tuple"}},
47 {BOOL_Default_List, {"bool", "DefaultFormat", "list"}},
48 {I8_None, {"int8", ""}},
49 {I8_Default, {"int8", "DefaultFormat"}},
50 {I8_5HD, {"int8", "NC1HWC0"}},
51 {I8_FracZ, {"int8", "FRACTAL_Z"}},
52 {I8_FracNZ, {"int8", "FRACTAL_NZ"}},
53 {I8_C1HWNCoC0, {"int8", "C1HWNCoC0"}},
54 {I8_NCHW, {"int8", "NCHW"}},
55 {I8_NHWC, {"int8", "NHWC"}},
56 {I8_HWCN, {"int8", "HWCN"}},
57 {I8_NDHWC, {"int8", "NDHWC"}},
58 {I8_NCDHW, {"int8", "NCDHW"}},
59 {I8_ChannelLast, {"int8", "ChannelLast"}},
60 {I8_NDC1HWC0, {"int8", "NDC1HWC0"}},
61 {I8_NC1HWC0, {"int8", "NC1HWC0"}},
62 {I8_Default_Tuple, {"int8", "DefaultFormat", "tuple"}},
63 {I8_Default_List, {"int8", "DefaultFormat", "list"}},
64 {U8_None, {"uint8", ""}},
65 {U8_Default, {"uint8", "DefaultFormat"}},
66 {U8_5HD, {"uint8", "NC1HWC0"}},
67 {U8_FracZ, {"uint8", "FRACTAL_Z"}},
68 {U8_FracNZ, {"uint8", "FRACTAL_NZ"}},
69 {U8_C1HWNCoC0, {"uint8", "C1HWNCoC0"}},
70 {U8_NCHW, {"uint8", "NCHW"}},
71 {U8_NHWC, {"uint8", "NHWC"}},
72 {U8_HWCN, {"uint8", "HWCN"}},
73 {U8_NDHWC, {"uint8", "NDHWC"}},
74 {U8_NCDHW, {"uint8", "NCDHW"}},
75 {U8_ChannelLast, {"uint8", "ChannelLast"}},
76 {U8_NDC1HWC0, {"uint8", "NDC1HWC0"}},
77 {U8_NC1HWC0, {"uint8", "NC1HWC0"}},
78 {U8_Default_Tuple, {"uint8", "DefaultFormat", "tuple"}},
79 {U8_Default_List, {"uint8", "DefaultFormat", "list"}},
80 {I16_None, {"int16", ""}},
81 {I16_Default, {"int16", "DefaultFormat"}},
82 {I16_5HD, {"int16", "NC1HWC0"}},
83 {I16_FracZ, {"int16", "FRACTAL_Z"}},
84 {I16_FracNZ, {"int16", "FRACTAL_NZ"}},
85 {I16_C1HWNCoC0, {"int16", "C1HWNCoC0"}},
86 {I16_NCHW, {"int16", "NCHW"}},
87 {I16_NHWC, {"int16", "NHWC"}},
88 {I16_HWCN, {"int16", "HWCN"}},
89 {I16_NDHWC, {"int16", "NDHWC"}},
90 {I16_ChannelLast, {"int16", "ChannelLast"}},
91 {I16_Default_Tuple, {"int16", "DefaultFormat", "tuple"}},
92 {I16_Default_List, {"int16", "DefaultFormat", "list"}},
93 {U16_None, {"uint16", ""}},
94 {U16_Default, {"uint16", "DefaultFormat"}},
95 {U16_5HD, {"uint16", "NC1HWC0"}},
96 {U16_FracZ, {"uint16", "FRACTAL_Z"}},
97 {U16_FracNZ, {"uint16", "FRACTAL_NZ"}},
98 {U16_C1HWNCoC0, {"uint16", "C1HWNCoC0"}},
99 {U16_NCHW, {"uint16", "NCHW"}},
100 {U16_NHWC, {"uint16", "NHWC"}},
101 {U16_HWCN, {"uint16", "HWCN"}},
102 {U16_NDHWC, {"uint16", "NDHWC"}},
103 {U16_ChannelLast, {"uint16", "ChannelLast"}},
104 {U16_Default_Tuple, {"uint16", "DefaultFormat", "tuple"}},
105 {U16_Default_List, {"uint16", "DefaultFormat", "list"}},
106 {I32_None, {"int32", ""}},
107 {I32_Default, {"int32", "DefaultFormat"}},
108 {I32_5HD, {"int32", "NC1HWC0"}},
109 {I32_FracZ, {"int32", "FRACTAL_Z"}},
110 {I32_FracNZ, {"int32", "FRACTAL_NZ"}},
111 {I32_C1HWNCoC0, {"int32", "C1HWNCoC0"}},
112 {I32_NCHW, {"int32", "NCHW"}},
113 {I32_NHWC, {"int32", "NHWC"}},
114 {I32_HWCN, {"int32", "HWCN"}},
115 {I32_NDHWC, {"int32", "NDHWC"}},
116 {I32_NDC1HWC0, {"int32", "NDC1HWC0"}},
117 {I32_NCDHW, {"int32", "NCDHW"}},
118 {I32_ChannelLast, {"int32", "ChannelLast"}},
119 {I32_Default_Tuple, {"int32", "DefaultFormat", "tuple"}},
120 {I32_Default_List, {"int32", "DefaultFormat", "list"}},
121 {U32_None, {"uint32", ""}},
122 {U32_Default, {"uint32", "DefaultFormat"}},
123 {U32_5HD, {"uint32", "NC1HWC0"}},
124 {U32_FracZ, {"uint32", "FRACTAL_Z"}},
125 {U32_FracNZ, {"uint32", "FRACTAL_NZ"}},
126 {U32_C1HWNCoC0, {"uint32", "C1HWNCoC0"}},
127 {U32_NCHW, {"uint32", "NCHW"}},
128 {U32_NHWC, {"uint32", "NHWC"}},
129 {U32_HWCN, {"uint32", "HWCN"}},
130 {U32_NDHWC, {"uint32", "NDHWC"}},
131 {U32_ChannelLast, {"uint32", "ChannelLast"}},
132 {U32_Default_Tuple, {"uint32", "DefaultFormat", "tuple"}},
133 {U32_Default_List, {"uint32", "DefaultFormat", "list"}},
134 {I64_None, {"int64", ""}},
135 {I64_Default, {"int64", "DefaultFormat"}},
136 {I64_5HD, {"int64", "NC1HWC0"}},
137 {I64_FracZ, {"int64", "FRACTAL_Z"}},
138 {I64_FracNZ, {"int64", "FRACTAL_NZ"}},
139 {I64_C1HWNCoC0, {"int64", "C1HWNCoC0"}},
140 {I64_NCHW, {"int64", "NCHW"}},
141 {I64_NHWC, {"int64", "NHWC"}},
142 {I64_HWCN, {"int64", "HWCN"}},
143 {I64_NDHWC, {"int64", "NDHWC"}},
144 {I64_ChannelLast, {"int64", "ChannelLast"}},
145 {I64_Default_Tuple, {"int64", "DefaultFormat", "tuple"}},
146 {I64_Default_List, {"int64", "DefaultFormat", "list"}},
147 {U64_None, {"uint64", ""}},
148 {U64_Default, {"uint64", "DefaultFormat"}},
149 {U64_5HD, {"uint64", "NC1HWC0"}},
150 {U64_FracZ, {"uint64", "FRACTAL_Z"}},
151 {U64_FracNZ, {"uint64", "FRACTAL_NZ"}},
152 {U64_C1HWNCoC0, {"uint64", "C1HWNCoC0"}},
153 {U64_NCHW, {"uint64", "NCHW"}},
154 {U64_NHWC, {"uint64", "NHWC"}},
155 {U64_HWCN, {"uint64", "HWCN"}},
156 {U64_NDHWC, {"uint64", "NDHWC"}},
157 {U64_ChannelLast, {"uint64", "ChannelLast"}},
158 {U64_Default_Tuple, {"uint64", "DefaultFormat", "tuple"}},
159 {U64_Default_List, {"uint64", "DefaultFormat", "list"}},
160 {F16_None, {"float16", ""}},
161 {F16_Default, {"float16", "DefaultFormat"}},
162 {F16_5HD, {"float16", "NC1HWC0"}},
163 {F16_FracZ, {"float16", "FRACTAL_Z"}},
164 {F16_FracNZ, {"float16", "FRACTAL_NZ"}},
165 {F16_C1HWNCoC0, {"float16", "C1HWNCoC0"}},
166 {F16_NCHW, {"float16", "NCHW"}},
167 {F16_NHWC, {"float16", "NHWC"}},
168 {F16_HWCN, {"float16", "HWCN"}},
169 {F16_NDHWC, {"float16", "NDHWC"}},
170 {F16_NCDHW, {"float16", "NCDHW"}},
171 {F16_DHWCN, {"float16", "DHWCN"}},
172 {F16_NDC1HWC0, {"float16", "NDC1HWC0"}},
173 {F16_FRACTAL_Z_3D, {"float16", "FRACTAL_Z_3D"}},
174 {F16_FracZNLSTM, {"float16", "FRACTAL_ZN_LSTM"}},
175 {F16_FracZNRNN, {"float16", "FRACTAL_ZN_RNN"}},
176 {F16_ND_RNNBIAS, {"float16", "ND_RNN_BIAS"}},
177 {F16_ChannelLast, {"float16", "ChannelLast"}},
178 {F16_Default_Tuple, {"float16", "DefaultFormat", "tuple"}},
179 {F16_Default_List, {"float16", "DefaultFormat", "list"}},
180 {F32_None, {"float32", ""}},
181 {F32_Default, {"float32", "DefaultFormat"}},
182 {F32_5HD, {"float32", "NC1HWC0"}},
183 {F32_FracZ, {"float32", "FRACTAL_Z"}},
184 {F32_FracNZ, {"float32", "FRACTAL_NZ"}},
185 {F32_C1HWNCoC0, {"float32", "C1HWNCoC0"}},
186 {F32_NCHW, {"float32", "NCHW"}},
187 {F32_NHWC, {"float32", "NHWC"}},
188 {F32_HWCN, {"float32", "HWCN"}},
189 {F32_NDHWC, {"float32", "NDHWC"}},
190 {F32_NCDHW, {"float32", "NCDHW"}},
191 {F32_DHWCN, {"float32", "DHWCN"}},
192 {F32_NDC1HWC0, {"float32", "NDC1HWC0"}},
193 {F32_FRACTAL_Z_3D, {"float32", "FRACTAL_Z_3D"}},
194 {F32_FracZNLSTM, {"float32", "FRACTAL_ZN_LSTM"}},
195 {F32_FracZNRNN, {"float32", "FRACTAL_ZN_RNN"}},
196 {F32_ND_RNNBIAS, {"float32", "ND_RNN_BIAS"}},
197 {F32_ChannelLast, {"float32", "ChannelLast"}},
198 {F32_Default_Tuple, {"float32", "DefaultFormat", "tuple"}},
199 {F32_Default_List, {"float32", "DefaultFormat", "list"}},
200 {F64_None, {"float64", ""}},
201 {F64_Default, {"float64", "DefaultFormat"}},
202 {F64_5HD, {"float64", "NC1HWC0"}},
203 {F64_FracZ, {"float64", "FRACTAL_Z"}},
204 {F64_FracNZ, {"float64", "FRACTAL_NZ"}},
205 {F64_C1HWNCoC0, {"float64", "C1HWNCoC0"}},
206 {F64_NCHW, {"float64", "NCHW"}},
207 {F64_NHWC, {"float64", "NHWC"}},
208 {F64_HWCN, {"float64", "HWCN"}},
209 {F64_NDHWC, {"float64", "NDHWC"}},
210 {F64_ChannelLast, {"float64", "ChannelLast"}},
211 {F64_Default_Tuple, {"float64", "DefaultFormat", "tuple"}},
212 {F64_Default_List, {"float64", "DefaultFormat", "list"}},
213 {C64_Default, {"complex64", "DefaultFormat"}},
214 {C128_Default, {"complex128", "DefaultFormat"}},
215 };
216
// String-to-string table of operator attribute names.
// NOTE(review): judging by the identifier, the key is presumably the attribute
// name as given by the caller and the value the name it is adapted to; the
// lookup site is not part of this header, so confirm there.
const std::map<std::string, std::string> kOpAttrNameAdaptMap = {
  {"data_format", "format"},
  {"group", "groups"},
  {"transpose_a", "transpose_x1"},
  {"transpose_b", "transpose_x2"},
};
223
224 void ConvertConstScalarInputToTensor(const AnfNodePtr &input_node);
225
226 std::vector<TensorPtr> ConvertOutputToTensor(const mindspore::BaseRef &output);
227
228 STATUS OpSetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, const char *const *attr_names, ValueHandle attrs[],
229 size_t attr_num);
230
231 std::vector<BaseShapePtr> BuildShape(int64_t **out_shapes, size_t *out_dims, size_t out_num);
232
233 std::vector<TypePtr> BuildType(const DataTypeC *out_dtypes, size_t out_num);
234
235 AbstractBasePtr BuildAbstract(std::vector<BaseShapePtr> shapes, std::vector<TypePtr> types);
236
237 AbstractBasePtr GetAbstract(const TypePtr &type, const int64_t shape[], size_t shape_size, bool is_param = false);
238
239 AbstractBasePtr OpInferShapeAndType(const PrimitivePtr &prim, const mindspore::AbstractBasePtrList &args_abs_list);
240
241 STATUS CheckCustomOpInfo(const CustomOpInfo &info);
242
243 nlohmann::json ConvertOpInfoToJson(const CustomOpInfo &info);
244
245 size_t GetMaxMallocSize();
246
247 template <typename T>
GetScalarParam(const FuncGraphPtr & fg,T value,mindspore::TypeId type)248 ParameterPtr GetScalarParam(const FuncGraphPtr &fg, T value, mindspore::TypeId type) {
249 MS_EXCEPTION_IF_NULL(fg);
250 auto param = fg->add_parameter();
251 auto type_ptr = mindspore::TypeIdToType(type);
252 MS_EXCEPTION_IF_NULL(type_ptr);
253 auto tensor = std::make_shared<TensorImpl>(value, type_ptr);
254 tensor->set_param_info(std::make_shared<mindspore::ParamInfo>());
255 param->set_abstract(tensor->ToAbstract());
256 param->set_default_param(tensor);
257 return param;
258 }
259
// If `condition` evaluates to false: streams `message` into MS_LOG(ERROR) and
// returns `val` from the enclosing function. Otherwise does nothing.
// The do { ... } while wrapper makes the expansion a single statement, so the
// macro can be followed by ';' and used as the body of an unbraced if/else.
#define MS_ERROR_IF_FALSE_W_RET_N_LOG(condition, val, message) \
  do {                                                         \
    if (!(condition)) {                                        \
      MS_LOG(ERROR) << message;                                \
      return val;                                              \
    }                                                          \
  } while (false)
267
// If `condition` evaluates to true: streams `message` into MS_LOG(ERROR) and
// returns `val` from the enclosing function. Otherwise does nothing.
// The do { ... } while wrapper makes the expansion a single statement, so the
// macro can be followed by ';' and used as the body of an unbraced if/else.
#define MS_ERROR_IF_TRUE_W_RET_N_LOG(condition, val, message) \
  do {                                                        \
    if ((condition)) {                                        \
      MS_LOG(ERROR) << message;                               \
      return val;                                             \
    }                                                         \
  } while (false)
275 #endif // MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
276