/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
18 #define MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <map>
23 #include <string>
24 #include <nlohmann/json.hpp>
25 #include "base/base.h"
26 #include "base/base_ref.h"
27 #include "c_api/src/resource_manager.h"
28 #include "include/c_api/ms/context.h"
29 #include "include/c_api/ms/node.h"
30 #include "c_api/src/common.h"
31 
32 const std::map<DTypeFormat, std::vector<std::string>> kDTypeFmtEnumToStrMap = {
33   {None_None, {"", ""}},
34   {None_Default, {"", "DefaultFormat"}},
35   {BOOL_None, {"bool", ""}},
36   {BOOL_Default, {"bool", "DefaultFormat"}},
37   {BOOL_5HD, {"bool", "NC1HWC0"}},
38   {BOOL_FracZ, {"bool", "FRACTAL_Z"}},
39   {BOOL_FracNZ, {"bool", "FRACTAL_NZ"}},
40   {BOOL_C1HWNCoC0, {"bool", "C1HWNCoC0"}},
41   {BOOL_NCHW, {"bool", "NCHW"}},
42   {BOOL_NHWC, {"bool", "NHWC"}},
43   {BOOL_HWCN, {"bool", "HWCN"}},
44   {BOOL_NDHWC, {"bool", "NDHWC"}},
45   {BOOL_ChannelLast, {"bool", "ChannelLast"}},
46   {BOOL_Default_Tuple, {"bool", "DefaultFormat", "tuple"}},
47   {BOOL_Default_List, {"bool", "DefaultFormat", "list"}},
48   {I8_None, {"int8", ""}},
49   {I8_Default, {"int8", "DefaultFormat"}},
50   {I8_5HD, {"int8", "NC1HWC0"}},
51   {I8_FracZ, {"int8", "FRACTAL_Z"}},
52   {I8_FracNZ, {"int8", "FRACTAL_NZ"}},
53   {I8_C1HWNCoC0, {"int8", "C1HWNCoC0"}},
54   {I8_NCHW, {"int8", "NCHW"}},
55   {I8_NHWC, {"int8", "NHWC"}},
56   {I8_HWCN, {"int8", "HWCN"}},
57   {I8_NDHWC, {"int8", "NDHWC"}},
58   {I8_NCDHW, {"int8", "NCDHW"}},
59   {I8_ChannelLast, {"int8", "ChannelLast"}},
60   {I8_NDC1HWC0, {"int8", "NDC1HWC0"}},
61   {I8_NC1HWC0, {"int8", "NC1HWC0"}},
62   {I8_Default_Tuple, {"int8", "DefaultFormat", "tuple"}},
63   {I8_Default_List, {"int8", "DefaultFormat", "list"}},
64   {U8_None, {"uint8", ""}},
65   {U8_Default, {"uint8", "DefaultFormat"}},
66   {U8_5HD, {"uint8", "NC1HWC0"}},
67   {U8_FracZ, {"uint8", "FRACTAL_Z"}},
68   {U8_FracNZ, {"uint8", "FRACTAL_NZ"}},
69   {U8_C1HWNCoC0, {"uint8", "C1HWNCoC0"}},
70   {U8_NCHW, {"uint8", "NCHW"}},
71   {U8_NHWC, {"uint8", "NHWC"}},
72   {U8_HWCN, {"uint8", "HWCN"}},
73   {U8_NDHWC, {"uint8", "NDHWC"}},
74   {U8_NCDHW, {"uint8", "NCDHW"}},
75   {U8_ChannelLast, {"uint8", "ChannelLast"}},
76   {U8_NDC1HWC0, {"uint8", "NDC1HWC0"}},
77   {U8_NC1HWC0, {"uint8", "NC1HWC0"}},
78   {U8_Default_Tuple, {"uint8", "DefaultFormat", "tuple"}},
79   {U8_Default_List, {"uint8", "DefaultFormat", "list"}},
80   {I16_None, {"int16", ""}},
81   {I16_Default, {"int16", "DefaultFormat"}},
82   {I16_5HD, {"int16", "NC1HWC0"}},
83   {I16_FracZ, {"int16", "FRACTAL_Z"}},
84   {I16_FracNZ, {"int16", "FRACTAL_NZ"}},
85   {I16_C1HWNCoC0, {"int16", "C1HWNCoC0"}},
86   {I16_NCHW, {"int16", "NCHW"}},
87   {I16_NHWC, {"int16", "NHWC"}},
88   {I16_HWCN, {"int16", "HWCN"}},
89   {I16_NDHWC, {"int16", "NDHWC"}},
90   {I16_ChannelLast, {"int16", "ChannelLast"}},
91   {I16_Default_Tuple, {"int16", "DefaultFormat", "tuple"}},
92   {I16_Default_List, {"int16", "DefaultFormat", "list"}},
93   {U16_None, {"uint16", ""}},
94   {U16_Default, {"uint16", "DefaultFormat"}},
95   {U16_5HD, {"uint16", "NC1HWC0"}},
96   {U16_FracZ, {"uint16", "FRACTAL_Z"}},
97   {U16_FracNZ, {"uint16", "FRACTAL_NZ"}},
98   {U16_C1HWNCoC0, {"uint16", "C1HWNCoC0"}},
99   {U16_NCHW, {"uint16", "NCHW"}},
100   {U16_NHWC, {"uint16", "NHWC"}},
101   {U16_HWCN, {"uint16", "HWCN"}},
102   {U16_NDHWC, {"uint16", "NDHWC"}},
103   {U16_ChannelLast, {"uint16", "ChannelLast"}},
104   {U16_Default_Tuple, {"uint16", "DefaultFormat", "tuple"}},
105   {U16_Default_List, {"uint16", "DefaultFormat", "list"}},
106   {I32_None, {"int32", ""}},
107   {I32_Default, {"int32", "DefaultFormat"}},
108   {I32_5HD, {"int32", "NC1HWC0"}},
109   {I32_FracZ, {"int32", "FRACTAL_Z"}},
110   {I32_FracNZ, {"int32", "FRACTAL_NZ"}},
111   {I32_C1HWNCoC0, {"int32", "C1HWNCoC0"}},
112   {I32_NCHW, {"int32", "NCHW"}},
113   {I32_NHWC, {"int32", "NHWC"}},
114   {I32_HWCN, {"int32", "HWCN"}},
115   {I32_NDHWC, {"int32", "NDHWC"}},
116   {I32_NDC1HWC0, {"int32", "NDC1HWC0"}},
117   {I32_NCDHW, {"int32", "NCDHW"}},
118   {I32_ChannelLast, {"int32", "ChannelLast"}},
119   {I32_Default_Tuple, {"int32", "DefaultFormat", "tuple"}},
120   {I32_Default_List, {"int32", "DefaultFormat", "list"}},
121   {U32_None, {"uint32", ""}},
122   {U32_Default, {"uint32", "DefaultFormat"}},
123   {U32_5HD, {"uint32", "NC1HWC0"}},
124   {U32_FracZ, {"uint32", "FRACTAL_Z"}},
125   {U32_FracNZ, {"uint32", "FRACTAL_NZ"}},
126   {U32_C1HWNCoC0, {"uint32", "C1HWNCoC0"}},
127   {U32_NCHW, {"uint32", "NCHW"}},
128   {U32_NHWC, {"uint32", "NHWC"}},
129   {U32_HWCN, {"uint32", "HWCN"}},
130   {U32_NDHWC, {"uint32", "NDHWC"}},
131   {U32_ChannelLast, {"uint32", "ChannelLast"}},
132   {U32_Default_Tuple, {"uint32", "DefaultFormat", "tuple"}},
133   {U32_Default_List, {"uint32", "DefaultFormat", "list"}},
134   {I64_None, {"int64", ""}},
135   {I64_Default, {"int64", "DefaultFormat"}},
136   {I64_5HD, {"int64", "NC1HWC0"}},
137   {I64_FracZ, {"int64", "FRACTAL_Z"}},
138   {I64_FracNZ, {"int64", "FRACTAL_NZ"}},
139   {I64_C1HWNCoC0, {"int64", "C1HWNCoC0"}},
140   {I64_NCHW, {"int64", "NCHW"}},
141   {I64_NHWC, {"int64", "NHWC"}},
142   {I64_HWCN, {"int64", "HWCN"}},
143   {I64_NDHWC, {"int64", "NDHWC"}},
144   {I64_ChannelLast, {"int64", "ChannelLast"}},
145   {I64_Default_Tuple, {"int64", "DefaultFormat", "tuple"}},
146   {I64_Default_List, {"int64", "DefaultFormat", "list"}},
147   {U64_None, {"uint64", ""}},
148   {U64_Default, {"uint64", "DefaultFormat"}},
149   {U64_5HD, {"uint64", "NC1HWC0"}},
150   {U64_FracZ, {"uint64", "FRACTAL_Z"}},
151   {U64_FracNZ, {"uint64", "FRACTAL_NZ"}},
152   {U64_C1HWNCoC0, {"uint64", "C1HWNCoC0"}},
153   {U64_NCHW, {"uint64", "NCHW"}},
154   {U64_NHWC, {"uint64", "NHWC"}},
155   {U64_HWCN, {"uint64", "HWCN"}},
156   {U64_NDHWC, {"uint64", "NDHWC"}},
157   {U64_ChannelLast, {"uint64", "ChannelLast"}},
158   {U64_Default_Tuple, {"uint64", "DefaultFormat", "tuple"}},
159   {U64_Default_List, {"uint64", "DefaultFormat", "list"}},
160   {F16_None, {"float16", ""}},
161   {F16_Default, {"float16", "DefaultFormat"}},
162   {F16_5HD, {"float16", "NC1HWC0"}},
163   {F16_FracZ, {"float16", "FRACTAL_Z"}},
164   {F16_FracNZ, {"float16", "FRACTAL_NZ"}},
165   {F16_C1HWNCoC0, {"float16", "C1HWNCoC0"}},
166   {F16_NCHW, {"float16", "NCHW"}},
167   {F16_NHWC, {"float16", "NHWC"}},
168   {F16_HWCN, {"float16", "HWCN"}},
169   {F16_NDHWC, {"float16", "NDHWC"}},
170   {F16_NCDHW, {"float16", "NCDHW"}},
171   {F16_DHWCN, {"float16", "DHWCN"}},
172   {F16_NDC1HWC0, {"float16", "NDC1HWC0"}},
173   {F16_FRACTAL_Z_3D, {"float16", "FRACTAL_Z_3D"}},
174   {F16_FracZNLSTM, {"float16", "FRACTAL_ZN_LSTM"}},
175   {F16_FracZNRNN, {"float16", "FRACTAL_ZN_RNN"}},
176   {F16_ND_RNNBIAS, {"float16", "ND_RNN_BIAS"}},
177   {F16_ChannelLast, {"float16", "ChannelLast"}},
178   {F16_Default_Tuple, {"float16", "DefaultFormat", "tuple"}},
179   {F16_Default_List, {"float16", "DefaultFormat", "list"}},
180   {F32_None, {"float32", ""}},
181   {F32_Default, {"float32", "DefaultFormat"}},
182   {F32_5HD, {"float32", "NC1HWC0"}},
183   {F32_FracZ, {"float32", "FRACTAL_Z"}},
184   {F32_FracNZ, {"float32", "FRACTAL_NZ"}},
185   {F32_C1HWNCoC0, {"float32", "C1HWNCoC0"}},
186   {F32_NCHW, {"float32", "NCHW"}},
187   {F32_NHWC, {"float32", "NHWC"}},
188   {F32_HWCN, {"float32", "HWCN"}},
189   {F32_NDHWC, {"float32", "NDHWC"}},
190   {F32_NCDHW, {"float32", "NCDHW"}},
191   {F32_DHWCN, {"float32", "DHWCN"}},
192   {F32_NDC1HWC0, {"float32", "NDC1HWC0"}},
193   {F32_FRACTAL_Z_3D, {"float32", "FRACTAL_Z_3D"}},
194   {F32_FracZNLSTM, {"float32", "FRACTAL_ZN_LSTM"}},
195   {F32_FracZNRNN, {"float32", "FRACTAL_ZN_RNN"}},
196   {F32_ND_RNNBIAS, {"float32", "ND_RNN_BIAS"}},
197   {F32_ChannelLast, {"float32", "ChannelLast"}},
198   {F32_Default_Tuple, {"float32", "DefaultFormat", "tuple"}},
199   {F32_Default_List, {"float32", "DefaultFormat", "list"}},
200   {F64_None, {"float64", ""}},
201   {F64_Default, {"float64", "DefaultFormat"}},
202   {F64_5HD, {"float64", "NC1HWC0"}},
203   {F64_FracZ, {"float64", "FRACTAL_Z"}},
204   {F64_FracNZ, {"float64", "FRACTAL_NZ"}},
205   {F64_C1HWNCoC0, {"float64", "C1HWNCoC0"}},
206   {F64_NCHW, {"float64", "NCHW"}},
207   {F64_NHWC, {"float64", "NHWC"}},
208   {F64_HWCN, {"float64", "HWCN"}},
209   {F64_NDHWC, {"float64", "NDHWC"}},
210   {F64_ChannelLast, {"float64", "ChannelLast"}},
211   {F64_Default_Tuple, {"float64", "DefaultFormat", "tuple"}},
212   {F64_Default_List, {"float64", "DefaultFormat", "list"}},
213   {C64_Default, {"complex64", "DefaultFormat"}},
214   {C128_Default, {"complex128", "DefaultFormat"}},
215 };
216 
// Renaming table for operator attribute names: the key is the name to look up, the mapped value is
// the name to use in its place.
// NOTE(review): which side is the C-API-facing name and which is the internal one is not visible
// in this header -- confirm at the use site.
const std::map<std::string, std::string> kOpAttrNameAdaptMap = {
  {"data_format", "format"},
  {"group", "groups"},
  {"transpose_a", "transpose_x1"},
  {"transpose_b", "transpose_x2"},
};
223 
224 void ConvertConstScalarInputToTensor(const AnfNodePtr &input_node);
225 
226 std::vector<TensorPtr> ConvertOutputToTensor(const mindspore::BaseRef &output);
227 
228 STATUS OpSetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, const char *const *attr_names, ValueHandle attrs[],
229                   size_t attr_num);
230 
231 std::vector<BaseShapePtr> BuildShape(int64_t **out_shapes, size_t *out_dims, size_t out_num);
232 
233 std::vector<TypePtr> BuildType(const DataTypeC *out_dtypes, size_t out_num);
234 
235 AbstractBasePtr BuildAbstract(std::vector<BaseShapePtr> shapes, std::vector<TypePtr> types);
236 
237 AbstractBasePtr GetAbstract(const TypePtr &type, const int64_t shape[], size_t shape_size, bool is_param = false);
238 
239 AbstractBasePtr OpInferShapeAndType(const PrimitivePtr &prim, const mindspore::AbstractBasePtrList &args_abs_list);
240 
241 STATUS CheckCustomOpInfo(const CustomOpInfo &info);
242 
243 nlohmann::json ConvertOpInfoToJson(const CustomOpInfo &info);
244 
245 size_t GetMaxMallocSize();
246 
247 template <typename T>
GetScalarParam(const FuncGraphPtr & fg,T value,mindspore::TypeId type)248 ParameterPtr GetScalarParam(const FuncGraphPtr &fg, T value, mindspore::TypeId type) {
249   MS_EXCEPTION_IF_NULL(fg);
250   auto param = fg->add_parameter();
251   auto type_ptr = mindspore::TypeIdToType(type);
252   MS_EXCEPTION_IF_NULL(type_ptr);
253   auto tensor = std::make_shared<TensorImpl>(value, type_ptr);
254   tensor->set_param_info(std::make_shared<mindspore::ParamInfo>());
255   param->set_abstract(tensor->ToAbstract());
256   param->set_default_param(tensor);
257   return param;
258 }
259 
// If `condition` evaluates to false: logs `message` via MS_LOG(ERROR) and returns `val` from the
// enclosing function. Otherwise does nothing.
// `message` is pasted directly after `MS_LOG(ERROR) <<`, so it may itself be a `<<` chain.
// The do { } while (0) wrapper makes the expansion a single statement that needs a trailing ';'.
#define MS_ERROR_IF_FALSE_W_RET_N_LOG(condition, val, message) \
  do {                                                         \
    if (!(condition)) {                                        \
      MS_LOG(ERROR) << message;                                \
      return val;                                              \
    }                                                          \
  } while (0)
267 
// If `condition` evaluates to true: logs `message` via MS_LOG(ERROR) and returns `val` from the
// enclosing function. Otherwise does nothing.
// `message` is pasted directly after `MS_LOG(ERROR) <<`, so it may itself be a `<<` chain.
// The do { } while (0) wrapper makes the expansion a single statement that needs a trailing ';'.
#define MS_ERROR_IF_TRUE_W_RET_N_LOG(condition, val, message) \
  do {                                                        \
    if (condition) {                                          \
      MS_LOG(ERROR) << message;                               \
      return val;                                             \
    }                                                         \
  } while (0)
275 #endif  // MINDSPORE_CCSRC_C_API_SRC_UTILS_H_
276