/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/c_api/ms/node.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "c_api/src/helper.h"
#include "c_api/src/common.h"
#include "c_api/src/utils.h"
#include "base/base.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/scope.h"
#include "ir/func_graph_cloner.h"
#include "include/backend/optimizer/helper.h"
#include "kernel/oplib/oplib.h"
#include "kernel/oplib/opinfo.h"
#include "abstract/dshape.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/pynative_utils.h"
#include "mindspore/core/ops/other_ops.h"

constexpr size_t firstInIdx = 1;
constexpr size_t secondInIdx = 2;
constexpr size_t switchInputNum = 3;
static const size_t maxMallocSize = GetMaxMallocSize();
NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, Handle const inputs[],
                   size_t input_num, const char *const *attr_names, ValueHandle attrs[], size_t attr_num) {
  if (res_mgr == nullptr || graph == nullptr || op_type == nullptr || inputs == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [op_type] or [inputs] is nullptr.";
    return nullptr;
  }
  // convert raw input pointer to source shared pointer
  auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
  if (res_fg == nullptr) {
    MS_LOG(ERROR) << "Get source pointer failed.";
    return nullptr;
  }
  auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
  std::vector<AnfNodePtr> cnode_inputs{};
  mindspore::AbstractBasePtrList abs_list{};
  auto prim = std::make_shared<PrimitiveImpl>(op_type);
  if (attr_names != nullptr && attrs != nullptr) {
    auto ret = OpSetAttrs(res_mgr, prim, attr_names, attrs, attr_num);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Op set attributes failed.";
      return nullptr;
    }
  }
  auto prim_node = mindspore::NewValueNode(prim);
  cnode_inputs.push_back(prim_node);
  CNodePtr cnode = nullptr;
  try {
    for (size_t i = 0; i < input_num; ++i) {
      auto input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
      MS_EXCEPTION_IF_NULL(input);
      if (input->isa<ParameterImpl>() && input->func_graph() != res_fg) {
        (void)res_fg->AddFreeVariable(input);
      }
      ConvertConstScalarInputToTensor(input);
      cnode_inputs.push_back(input);
      abs_list.push_back(input->abstract());
    }
    cnode = res_fg->NewCNodeInOrder(cnode_inputs);
    MS_EXCEPTION_IF_NULL(cnode);
    if (res_mgr_ptr->GetInfer()) {
      auto out_abs = OpInferShapeAndType(prim, abs_list);
      cnode->set_abstract(out_abs);
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "FuncGraph create new operator failed. Error info: " << e.what();
    return nullptr;
  }
  MS_LOG(INFO) << "Add Operator: " << op_type;
  return GetRawPtr(res_mgr, cnode);
}
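
// Usage sketch (illustrative only; res_mgr/graph/x/y are hypothetical handles
// assumed to be created earlier through this C API):
//   Handle inputs[] = {x, y};
//   NodeHandle add = MSNewOp(res_mgr, graph, "Add", inputs, 2, NULL, NULL, 0);
//   if (add == NULL) { /* creation failed, see error log */ }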

NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, Handle const nodes[], size_t node_num) {
  if (res_mgr == nullptr || graph == nullptr || nodes == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [nodes] is nullptr.";
    return nullptr;
  }
  CNodePtr make_tuple_cnode = nullptr;
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    std::vector<AnfNodePtr> in_nodes{NewValueNode(mindspore::prim::kPrimMakeTuple)};
    mindspore::AbstractBasePtrList abs_list{};
    for (size_t i = 0; i < node_num; ++i) {
      auto in_node = GetSrcPtr<AnfNodePtr>(res_mgr, nodes[i]);
      MS_EXCEPTION_IF_NULL(in_node);
      in_nodes.push_back(in_node);
      ConvertConstScalarInputToTensor(in_node);
      abs_list.push_back(in_node->abstract());
    }
    make_tuple_cnode = res_fg->NewCNodeInOrder(in_nodes);
    MS_EXCEPTION_IF_NULL(make_tuple_cnode);
    make_tuple_cnode->set_abstract(std::make_shared<AbstractTupleImpl>(abs_list));
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "FuncGraph create MakeTuple node failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, make_tuple_cnode);
}
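
// Usage sketch (illustrative; node handles n1/n2 are hypothetical):
//   Handle nodes[] = {n1, n2};
//   NodeHandle tuple = MSPackNodesTuple(res_mgr, graph, nodes, 2);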

NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op, size_t i) {
  if (res_mgr == nullptr || graph == nullptr || op == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [op] is nullptr.";
    return nullptr;
  }
  CNodePtr ret_node = nullptr;
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    auto cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
    MS_EXCEPTION_IF_NULL(cnode);
    auto abs = cnode->abstract();
    if (abs == nullptr) {
      MS_LOG(ERROR) << "Input op's abstract is nullptr!";
      return nullptr;
    }
    if (abs->isa<mindspore::abstract::AbstractTuple>()) {
      auto branch_num = abs->cast<mindspore::abstract::AbstractTuplePtr>()->size();
      if (i >= branch_num) {
        MS_LOG(ERROR) << "Invalid output branch index, it should be less than " << branch_num << ", but got: " << i;
        return nullptr;
      }
      auto idx = mindspore::NewValueNode(mindspore::SizeToLong(i));
      auto abs_scalar = std::make_shared<mindspore::abstract::AbstractScalar>(mindspore::SizeToInt(i));
      idx->set_abstract(abs_scalar);
      ret_node = res_fg->NewCNodeInOrder({NewValueNode(mindspore::prim::kPrimTupleGetItem), cnode, idx});
      MS_EXCEPTION_IF_NULL(ret_node);
      ret_node->set_abstract(abs->cast<mindspore::abstract::AbstractTuplePtr>()->elements()[i]);
    } else {
      if (i >= 1) {
        MS_LOG(ERROR) << "Invalid output index. The op has only one output, so the index must be 0 (or the op itself "
                         "can be used directly as the output without calling this function), but got: "
                      << i;
        return nullptr;
      }
      MS_LOG(WARNING) << "The op has only one output, you can directly use this op as the output without calling this "
                         "function. Now the op itself is returned.";
      ret_node = cnode;
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "FuncGraph get output failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, ret_node);
}
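
// Usage sketch (illustrative; assumes `op` is a multi-output node created above):
//   NodeHandle out0 = MSOpGetSpecOutput(res_mgr, graph, op, 0);  // wraps op in TupleGetItem(op, 0)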

CNodePtr BuildSwitchStructure(ResMgrHandle res_mgr, GraphHandle graph, NodeHandle const switch_input[],
                              size_t input_num, bool set_fg_out) {
  MS_EXCEPTION_IF_NULL(res_mgr);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(switch_input);
  MS_EXCEPTION_IF_CHECK_FAIL(input_num == switchInputNum, "Switch's input number must be 3!");
  NodeHandle switch_op = MSNewOp(res_mgr, graph, "Switch", switch_input, input_num, NULL, NULL, 0);
  if (switch_op == nullptr) {
    MS_LOG(ERROR) << "Get Switch op failed!";
    return nullptr;
  }
  auto src_switch = GetSrcPtr<CNodePtr>(res_mgr, switch_op);
  MS_EXCEPTION_IF_NULL(src_switch);
  auto fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
  MS_EXCEPTION_IF_NULL(fg);
  CNodePtr switch_call = fg->NewCNodeInOrder({src_switch});
  MS_EXCEPTION_IF_NULL(switch_call);
  if (set_fg_out) {
    fg->set_output(switch_call);
  }
  auto first_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[firstInIdx]);
  MS_EXCEPTION_IF_NULL(first_node);
  auto second_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[secondInIdx]);
  MS_EXCEPTION_IF_NULL(second_node);
  // AddFuncGraphCNodeIndex records a funcgraph's cnode_index, which is a list of (CNODE, index) pairs. Each CNODE
  // lives in another funcgraph and uses this funcgraph as its input at position `index`. For example, if fg1's
  // cnode A uses fg2 as A's first input, then fg2's cnode_index contains (A, 1).
  if (first_node->isa<ValueNodeImpl>()) {
    fg->AddValueNode(first_node);
    if (mindspore::IsValueNode<FuncGraphImpl>(first_node)) {
      auto used = mindspore::GetValueNode<FuncGraphPtr>(first_node);
      used->AddFuncGraphCNodeIndex(
        std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, firstInIdx + 1)));
      (void)fg->AddFuncGraphUsed(used);
    }
  }
  if (second_node->isa<ValueNodeImpl>()) {
    fg->AddValueNode(second_node);
    if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
      auto used = mindspore::GetValueNode<FuncGraphPtr>(second_node);
      used->AddFuncGraphCNodeIndex(
        std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, secondInIdx + 1)));
      (void)fg->AddFuncGraphUsed(used);
    }
  }
  // Switch-call's abstract is equal to that of the second branch.
  if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
    auto sub_fg = mindspore::GetValueNode<FuncGraphPtr>(second_node);
    switch_call->set_abstract(sub_fg->output()->abstract());
  }
  return switch_call;
}

NodeHandle MSNewSwitch(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, ConstGraphHandle true_br,
                       ConstGraphHandle false_br) {
  if (res_mgr == nullptr || graph == nullptr || cond == nullptr || true_br == nullptr || false_br == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [true_br] or [false_br] is nullptr.";
    return nullptr;
  }
  try {
    auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
    MS_EXCEPTION_IF_NULL(src_cond);
    NodeHandle cond_raw_ptr = nullptr;
    if (src_cond->isa<FuncGraphImpl>()) {
      auto cond_graph = src_cond->cast<FuncGraphPtr>();
      MS_EXCEPTION_IF_NULL(cond_graph);
      auto cond_node = mindspore::NewValueNode(cond_graph);
      cond_node->set_abstract(cond_graph->ToAbstract());
      cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
    } else {
      cond_raw_ptr = cond;
    }
    auto true_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, true_br);
    MS_EXCEPTION_IF_NULL(true_fg);
    auto true_node = mindspore::NewValueNode(true_fg);
    true_node->set_abstract(true_fg->ToAbstract());
    NodeHandle true_raw_ptr = GetRawPtr(res_mgr, true_node);

    auto false_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, false_br);
    MS_EXCEPTION_IF_NULL(false_fg);
    auto false_node = mindspore::NewValueNode(false_fg);
    false_node->set_abstract(false_fg->ToAbstract());
    NodeHandle false_raw_ptr = GetRawPtr(res_mgr, false_node);

    NodeHandle switch_input[] = {cond_raw_ptr, true_raw_ptr, false_raw_ptr};
    auto switch_call = BuildSwitchStructure(res_mgr, graph, switch_input, switchInputNum, false);
    MS_EXCEPTION_IF_NULL(switch_call);
    return GetRawPtr(res_mgr, switch_call);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New Switch node failed. Error info: " << e.what();
    return nullptr;
  }
}
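
// Usage sketch (illustrative; cond may be a boolean node or a condition graph,
// true_g/false_g are hypothetical GraphHandles whose outputs share an abstract):
//   NodeHandle branch_call = MSNewSwitch(res_mgr, graph, cond, true_g, false_g);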

void HandleFVInWhileGraph(const FuncGraphPtr &main_fg, const FuncGraphPtr &body_fg, const FuncGraphPtr &after_fg) {
  std::vector<AnfNodePtr> fv_to_restore{};
  auto body_fvs = body_fg->free_variables();
  for (const auto &fv : body_fvs) {
    auto fv_node = fv.first;
    MS_EXCEPTION_IF_NULL(fv_node);
    if (fv_node->func_graph() != main_fg &&
        std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
      fv_to_restore.push_back(fv_node);
    }
  }
  auto after_fvs = after_fg->free_variables();
  for (const auto &fv : after_fvs) {
    auto fv_node = fv.first;
    MS_EXCEPTION_IF_NULL(fv_node);
    if (fv_node->func_graph() != main_fg &&
        std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
      fv_to_restore.push_back(fv_node);
    }
  }

  (void)mindspore::LiftingClone(main_fg);

  auto main_manager = Manage(main_fg);
  std::vector<AnfNodePtr> new_main_params{};
  auto main_params = main_fg->parameters();
  for (const auto &main_param : main_params) {
    auto src_main_param = main_param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(src_main_param);
    auto found_in_fv_list =
      find_if(fv_to_restore.begin(), fv_to_restore.end(), [&main_param](const AnfNodePtr &fv_param) {
        return !main_param->ToString().empty() && main_param->ToString() == fv_param->ToString();
      });
    if (found_in_fv_list != fv_to_restore.end()) {
      (void)main_manager->Replace(main_param, *found_in_fv_list);
    } else if (src_main_param->has_default()) {
      auto const_input = mindspore::NewValueNode(src_main_param->default_param());
      const_input->set_abstract(src_main_param->abstract());
      (void)main_manager->Replace(main_param, const_input);
    } else {
      new_main_params.push_back(main_param);
    }
  }
  main_fg->set_parameters(new_main_params);
}

NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, GraphHandle body_graph,
                      GraphHandle after_graph) {
  if (res_mgr == nullptr || graph == nullptr || cond == nullptr || body_graph == nullptr || after_graph == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [body_graph] or [after_graph] is nullptr.";
    return nullptr;
  }
  try {
    auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
    MS_EXCEPTION_IF_NULL(src_cond);
    NodeHandle cond_raw_ptr = nullptr;
    GraphHandle cond_graph = nullptr;
    FuncGraphPtr src_cond_graph = nullptr;
    auto main_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    if (src_cond->isa<FuncGraphImpl>()) {
      cond_graph = cond;
      src_cond_graph = src_cond->cast<FuncGraphPtr>();
      MS_EXCEPTION_IF_NULL(src_cond_graph);
      auto cond_node = src_cond_graph->output();
      MS_EXCEPTION_IF_NULL(cond_node);
      cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
    } else {
      auto cond_fg = std::make_shared<FuncGraphImpl>();
      MS_EXCEPTION_IF_NULL(cond_fg);
      cond_graph = GetRawPtr(res_mgr, cond_fg);
      MS_EXCEPTION_IF_NULL(cond_graph);
      src_cond_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, cond_graph);
      MS_EXCEPTION_IF_NULL(src_cond_graph);
      (void)main_fg->AddFuncGraphUsed(src_cond_graph);
      if (src_cond->isa<CNodeImpl>()) {
        auto cond_node = src_cond->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(cond_node);
        auto new_cond = src_cond_graph->NewCNodeInOrder(cond_node->inputs());
        MS_EXCEPTION_IF_NULL(new_cond);
        new_cond->set_abstract(cond_node->abstract());
        cond_raw_ptr = GetRawPtr(res_mgr, new_cond);
      } else {
        cond_raw_ptr = cond;
      }
    }

    auto body_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, body_graph);
    MS_EXCEPTION_IF_NULL(body_fg);
    auto body_node = mindspore::NewValueNode(body_fg);
    body_node->set_abstract(body_fg->ToAbstract());
    NodeHandle body_raw_ptr = GetRawPtr(res_mgr, body_node);

    auto after_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, after_graph);
    MS_EXCEPTION_IF_NULL(after_fg);
    auto after_node = mindspore::NewValueNode(after_fg);
    after_node->set_abstract(after_fg->ToAbstract());
    NodeHandle after_raw_ptr = GetRawPtr(res_mgr, after_node);

    NodeHandle switch_input[] = {cond_raw_ptr, body_raw_ptr, after_raw_ptr};
    (void)BuildSwitchStructure(res_mgr, cond_graph, switch_input, switchInputNum, true);

    // handle main graph call
    NodeHandle main_func_call = MSNewFuncCallNode(res_mgr, graph, cond_graph, nullptr, 0);
    auto src_call = GetSrcPtr<AnfNodePtr>(res_mgr, main_func_call);
    main_fg->set_output(src_call);

    // handle free variables in the while graphs
    HandleFVInWhileGraph(main_fg, body_fg, after_fg);

    // handle multiple outputs in body graph
    auto sub_graph_node = mindspore::NewValueNode(src_cond_graph);
    sub_graph_node->set_abstract(src_cond_graph->ToAbstract());
    std::vector<AnfNodePtr> sub_input_nodes{sub_graph_node};
    auto body_out_node = body_fg->output();
    MS_EXCEPTION_IF_NULL(body_out_node);
    if (IsPrimitiveCNode(body_out_node, mindspore::prim::kPrimMakeTuple)) {
      auto body_out_cnode = body_out_node->cast<CNodePtr>();
      for (size_t i = 1; i < body_out_cnode->size(); i++) {
        sub_input_nodes.push_back(body_out_cnode->input(i));
      }
    } else {
      sub_input_nodes.push_back(body_out_node);
    }
    auto body_func_call = body_fg->NewCNodeInOrder(sub_input_nodes);
    MS_EXCEPTION_IF_NULL(src_cond_graph->output());
    MS_EXCEPTION_IF_NULL(body_func_call);
    body_func_call->set_abstract(src_cond_graph->output()->abstract());
    body_fg->set_output(body_func_call);
    return main_func_call;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New While node failed. Error info: " << e.what();
    return nullptr;
  }
}
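
// Usage sketch (illustrative; cond_g/body_g/after_g are hypothetical GraphHandles,
// where body_g loops while cond_g's output is true and after_g runs on exit):
//   NodeHandle while_call = MSNewWhile(res_mgr, graph, cond_g, body_g, after_g);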

std::vector<BaseShapePtr> CustomOpInferShape(const CustomOpInfo &info, const std::vector<AbstractBasePtr> &input_args) {
  auto dyn_arr_deleter = [](int64_t **x, size_t dims) {
    std::for_each(x, x + dims, std::default_delete<int64_t[]>());
    delete[] x;
  };
  if (info.output_shapes != nullptr) {
    if (info.output_dims == nullptr) {
      MS_LOG(ERROR) << "Output dims must be given if output shapes are specified!";
      return {};
    }
    auto infer_shape = BuildShape(info.output_shapes, info.output_dims, info.output_num);
    return infer_shape;
  } else if (info.shape_infer_func != nullptr) {
    size_t input_num = info.input_num;
    size_t output_num = info.output_num;
    MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num * sizeof(size_t) > maxMallocSize, {},
                                 "The input_num is too large for memory allocation.");
    MS_ERROR_IF_TRUE_W_RET_N_LOG(output_num * sizeof(size_t) > maxMallocSize, {},
                                 "The output_num is too large for memory allocation.");
    auto out_dims_arr = std::make_unique<size_t[]>(output_num);
    std::unique_ptr<int64_t *, std::function<void(int64_t **)>> out_shapes_arr(
      new (std::nothrow) int64_t *[output_num](), std::bind(dyn_arr_deleter, std::placeholders::_1, output_num));
    for (size_t i = 0; i < output_num; i++) {
      (out_shapes_arr.get())[i] = new int64_t[MAX_DIMS];
    }
    auto in_dims_arr = std::make_unique<size_t[]>(input_num);
    std::unique_ptr<int64_t *, std::function<void(int64_t **)>> in_shapes_arr(
      new (std::nothrow) int64_t *[input_num](), std::bind(dyn_arr_deleter, std::placeholders::_1, input_num));
    for (size_t i = 0; i < input_num; i++) {
      auto in_shape = input_args[i]->BuildShape();
      MS_EXCEPTION_IF_NULL(in_shape);
      auto in_shape_ptr = in_shape->cast<ShapePtr>();
      MS_EXCEPTION_IF_NULL(in_shape_ptr);
      auto in_shape_vec = in_shape_ptr->shape();
      auto in_shape_dim = in_shape_vec.size();
      in_dims_arr[i] = in_shape_dim;
      MS_ERROR_IF_TRUE_W_RET_N_LOG(in_shape_dim * sizeof(size_t) > maxMallocSize, {},
                                   "The in_shape_dim is too large for memory allocation.");
      (in_shapes_arr.get())[i] = new int64_t[in_shape_dim];
      for (size_t j = 0; j < in_shape_dim; j++) {
        (in_shapes_arr.get())[i][j] = in_shape_vec[j];
      }
    }
    auto ret = info.shape_infer_func(in_shapes_arr.get(), in_dims_arr.get(), input_num, out_shapes_arr.get(),
                                     out_dims_arr.get(), output_num);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Failed to call the shape infer function of custom op!";
      return {};
    }
    auto infer_shape = BuildShape(out_shapes_arr.get(), out_dims_arr.get(), output_num);
    return infer_shape;
  } else {
    MS_LOG(ERROR) << "Either output shape or output shape infer function must be specified!";
    return {};
  }
}
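
// Shape-infer callback sketch (illustrative; an identity-shape rule matching the
// call signature used above, where each out_shapes[i] has room for MAX_DIMS entries):
//   STATUS MyInferShape(int64_t **in_shapes, size_t *in_dims, size_t in_num,
//                       int64_t **out_shapes, size_t *out_dims, size_t out_num) {
//     for (size_t i = 0; i < out_num && i < in_num; ++i) {
//       out_dims[i] = in_dims[i];
//       for (size_t j = 0; j < in_dims[i]; ++j) out_shapes[i][j] = in_shapes[i][j];
//     }
//     return RET_OK;
//   }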

std::vector<TypePtr> CustomOpInferType(const CustomOpInfo &info, const std::vector<AbstractBasePtr> &input_args) {
  if (info.output_dtypes != nullptr) {
    auto infer_dtype = BuildType(info.output_dtypes, info.output_num);
    return infer_dtype;
  } else if (info.dtype_infer_func != nullptr) {
    size_t input_num = info.input_num;
    size_t output_num = info.output_num;
    auto in_dtypes_arr = std::make_unique<DataTypeC[]>(input_num);
    auto out_dtypes_arr = std::make_unique<DataTypeC[]>(output_num);
    for (size_t i = 0; i < input_num; i++) {
      auto in_type = input_args[i]->BuildType();
      MS_EXCEPTION_IF_NULL(in_type);
      auto real_type = in_type;
      if (in_type->isa<TensorTypeImpl>()) {
        auto tensor_type = in_type->cast<TensorTypePtr>();
        real_type = tensor_type->element();
      }
      auto in_type_id = (enum DataTypeC)(real_type->type_id());
      in_dtypes_arr[i] = in_type_id;
    }
    STATUS ret = info.dtype_infer_func(in_dtypes_arr.get(), input_num, out_dtypes_arr.get(), output_num);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Failed to call the dtype infer function of custom op!";
      return {};
    }
    auto infer_dtype = BuildType(out_dtypes_arr.get(), output_num);
    return infer_dtype;
  } else {
    MS_LOG(ERROR) << "Either output dtype or output dtype infer function must be specified!";
    return {};
  }
}
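
// Dtype-infer callback sketch (illustrative; propagates the first input's dtype
// to every output, matching the call signature used above):
//   STATUS MyInferType(DataTypeC *in_types, size_t in_num, DataTypeC *out_types, size_t out_num) {
//     if (in_num == 0) return RET_ERROR;
//     for (size_t i = 0; i < out_num; ++i) out_types[i] = in_types[0];
//     return RET_OK;
//   }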

NodeHandle MSNewCustomOp(ResMgrHandle res_mgr, GraphHandle graph, Handle const inputs[], size_t input_num,
                         CustomOpInfo info) {
  if (res_mgr == nullptr || graph == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
    return nullptr;
  }
  MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num != info.input_num, nullptr,
                               "Input node number does not match the input number specified in custom op info.");
  auto ret = CheckCustomOpInfo(info);
  MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Invalid custom op info.");
  try {
    auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
    auto org_infer = res_mgr_ptr->GetInfer();
    res_mgr_ptr->SetInfer(false);
    NodeHandle custom_op =
      MSNewOp(res_mgr, graph, "Custom", inputs, info.input_num, info.attr_names, info.attr_values, info.attr_num);
    MS_ERROR_IF_TRUE_W_RET_N_LOG(custom_op == nullptr, nullptr, "Create Custom op failed!");
    res_mgr_ptr->SetInfer(org_infer);
    // Supplement necessary attributes
    ret = MSOpSetAttrString(res_mgr, custom_op, mindspore::kAttrFuncType, info.func_type);
    MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Custom op set func type attribute failed.");
    ret = MSOpSetAttrString(res_mgr, custom_op, mindspore::kAttrFuncName, info.func_name);
    MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Custom op set func name attribute failed.");
    // Build json object
    nlohmann::json json_obj = ConvertOpInfoToJson(info);
    MS_ERROR_IF_TRUE_W_RET_N_LOG(json_obj.empty(), nullptr, "Failed to convert op info to json.");
    // Create op info and set info map
    auto op_name = json_obj.at(mindspore::kernel::kOpName).get<std::string>();
    auto imply_type = json_obj.at(mindspore::kernel::kImplyType).get<std::string>();
    std::string func_name = info.func_name;
    std::string target_name = info.target;
    auto iter = mindspore::kernel::kImplyTypeStrToEnumMap.find(imply_type);
    if (iter == mindspore::kernel::kImplyTypeStrToEnumMap.end()) {
      MS_LOG(ERROR) << "Unsupported imply_type: " << imply_type;
      return nullptr;
    }
    auto op_info = mindspore::kernel::OpLib::DecodeOpInfo(json_obj, iter->second, "");
    if (op_info == nullptr) {
      MS_LOG(ERROR) << "Decode op info failed: func_name: " << func_name << " imply_type " << imply_type;
      return nullptr;
    }
    op_info->set_processor(imply_type);
    auto key = op_name + imply_type;
    auto &op_infos = mindspore::kernel::OpLib::GetOpInfoMap();
    (void)op_infos[iter->second].insert(std::pair<std::string, mindspore::kernel::OpInfoPtr>(key, op_info));
    // Infer shape and type
    mindspore::AbstractBasePtrList abs_list{};
    for (size_t i = 0; i < input_num; ++i) {
      auto in_node = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
      MS_EXCEPTION_IF_NULL(in_node);
      abs_list.push_back(in_node->abstract());
    }
    auto infer_shape = CustomOpInferShape(info, abs_list);
    auto infer_type = CustomOpInferType(info, abs_list);
    AbstractBasePtr custom_abs = BuildAbstract(infer_shape, infer_type);
    MS_EXCEPTION_IF_NULL(custom_abs);
    auto src_op = GetSrcPtr<CNodePtr>(res_mgr, custom_op);
    MS_EXCEPTION_IF_NULL(src_op);
    src_op->set_abstract(custom_abs);
    return custom_op;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get custom op failed. Error info: " << e.what();
    return nullptr;
  }
}
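
// Usage sketch (illustrative; assumes a CustomOpInfo `info` filled with either
// static output shapes/dtypes or the infer callbacks sketched above):
//   NodeHandle custom = MSNewCustomOp(res_mgr, graph, inputs, info.input_num, info);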

NodeHandle MSOpGetInput(ResMgrHandle res_mgr, ConstNodeHandle op, size_t i) {
  if (res_mgr == nullptr || op == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
    return nullptr;
  }
  mindspore::AnfNodePtr anf_node = nullptr;
  try {
    auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
    MS_EXCEPTION_IF_NULL(src_cnode);
    if (i >= src_cnode->size() - 1) {
      MS_LOG(ERROR) << "Invalid input index, it should be less than " << src_cnode->size() - 1 << ", but got: " << i;
      return nullptr;
    }
    anf_node = src_cnode->input(i + 1);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get input from CNode failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, anf_node);
}

size_t MSOpGetInputsNum(ResMgrHandle res_mgr, ConstNodeHandle op, STATUS *error) {
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || op == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  size_t input_num;
  try {
    auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
    MS_EXCEPTION_IF_NULL(src_cnode);
    input_num = src_cnode->size() - 1;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "FuncGraph get input number failed. Error info: " << e.what();
    *error = RET_ERROR;
    return 0;
  }
  *error = RET_OK;
  return input_num;
}

STATUS MSOpGetInputs(ResMgrHandle res_mgr, ConstNodeHandle op, NodeHandle inputs[], size_t input_num) {
  if (res_mgr == nullptr || op == nullptr || inputs == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [inputs] is nullptr.";
    return RET_NULL_PTR;
  }
  try {
    auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
    MS_EXCEPTION_IF_NULL(src_cnode);
    auto in_num = src_cnode->size() - 1;
    if (in_num != input_num) {
      MS_LOG(ERROR) << "Invalid input number, it should be: " << in_num << ", but got: " << input_num;
      return RET_ERROR;
    }
    auto cnode_inputs = src_cnode->inputs();
    for (size_t i = 0; i < input_num; i++) {
      inputs[i] = GetRawPtr(res_mgr, cnode_inputs[i + 1]);
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get inputs from CNode failed. Error info: " << e.what();
    return RET_ERROR;
  }
  return RET_OK;
}
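
// Usage sketch (illustrative; reads back the real inputs of `op`, excluding the
// leading primitive input):
//   STATUS err;
//   size_t num = MSOpGetInputsNum(res_mgr, op, &err);
//   NodeHandle in_buf[8];  // assumes num <= 8 for this sketch
//   if (err == RET_OK && num <= 8) (void)MSOpGetInputs(res_mgr, op, in_buf, num);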

size_t MSOpGetOutputDimension(ResMgrHandle res_mgr, ConstNodeHandle op, size_t output_index, STATUS *ret) {
  if (ret == nullptr) {
    MS_LOG(ERROR) << "Input status flag [ret] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || op == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
    *ret = RET_NULL_PTR;
    return 0;
  }
  try {
    auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
    MS_EXCEPTION_IF_NULL(src_cnode);
    std::vector<int64_t> shape = mindspore::common::AnfAlgo::GetOutputInferShape(src_cnode, output_index);
    *ret = RET_OK;
    return shape.size();
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Shape from OP/CNode failed. Error info: " << e.what();
    *ret = RET_ERROR;
    return 0;
  }
}

STATUS MSOpGetOutputShape(ResMgrHandle res_mgr, ConstNodeHandle op, int64_t shape_ret[], size_t dim,
                          size_t output_index) {
  if (res_mgr == nullptr || op == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
    return RET_NULL_PTR;
  }
  try {
    auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
    MS_EXCEPTION_IF_NULL(src_cnode);
    std::vector<int64_t> shape = mindspore::common::AnfAlgo::GetOutputInferShape(src_cnode, output_index);
    MS_EXCEPTION_IF_CHECK_FAIL(
      dim >= shape.size(),
      "Input dimension is less than the actual dimension. Please ensure shape_ret has enough space.");
    (void)std::copy(shape.begin(), shape.end(), shape_ret);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Shape from OP/CNode failed. Error info: " << e.what();
    return RET_ERROR;
  }
  return RET_OK;
}
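
// Usage sketch (illustrative; queries the rank first so shape_ret is large enough):
//   STATUS err;
//   size_t dim = MSOpGetOutputDimension(res_mgr, op, 0, &err);
//   int64_t shape_buf[MAX_DIMS];  // assumes dim <= MAX_DIMS
//   if (err == RET_OK) (void)MSOpGetOutputShape(res_mgr, op, shape_buf, dim, 0);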

NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph, ConstGraphHandle sub_graph, Handle const inputs[],
                             size_t input_num) {
  if (res_mgr == nullptr || graph == nullptr || sub_graph == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [sub_graph] is nullptr.";
    return nullptr;
  }
  CNodePtr cnode = nullptr;
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    auto res_sub_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, sub_graph);
    MS_EXCEPTION_IF_NULL(res_sub_fg);
    auto sub_node = mindspore::NewValueNode(res_sub_fg);
    sub_node->set_abstract(res_sub_fg->ToAbstract());
    std::vector<AnfNodePtr> cnode_inputs{sub_node};
    for (size_t i = 0; i < input_num; ++i) {
      auto cnode_input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
      MS_EXCEPTION_IF_NULL(cnode_input);
      cnode_inputs.push_back(cnode_input);
    }
    cnode = res_fg->NewCNodeInOrder(cnode_inputs);
    MS_EXCEPTION_IF_NULL(res_sub_fg->output());
    cnode->set_abstract(res_sub_fg->output()->abstract());
    (void)res_fg->AddFuncGraphUsed(res_sub_fg);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "FuncGraph create SubGraph node failed. Error info: " << e.what();
    return nullptr;
  }
  MS_LOG(INFO) << "Add function call node";
  return GetRawPtr(res_mgr, cnode);
}

NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph, DataTypeC type, const int64_t shape[],
                            size_t shape_size) {
  if (res_mgr == nullptr || graph == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
    return nullptr;
  }
  ParameterPtr param = nullptr;
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    param = res_fg->add_parameter();
    auto type_ptr = mindspore::TypeIdToType(mindspore::TypeId(type));
    AbstractBasePtr abs = GetAbstract(type_ptr, shape, shape_size, true);
    param->set_abstract(abs);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "FuncGraph add parameter failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, param);
}
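
// Usage sketch (illustrative; MS_FLOAT32 stands for the corresponding DataTypeC
// enumerator, whose exact spelling lives in the public header):
//   int64_t shp[] = {2, 3};
//   NodeHandle x = MSNewPlaceholder(res_mgr, graph, MS_FLOAT32, shp, 2);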

NodeHandle MSNewVariableScalarFloat32(ResMgrHandle res_mgr, GraphHandle graph, float value) {
  if (res_mgr == nullptr || graph == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
    return nullptr;
  }
  ParameterPtr param = nullptr;
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    param = GetScalarParam<float>(res_fg, value, mindspore::kNumberTypeFloat32);
    MS_EXCEPTION_IF_NULL(param);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New Scalar Variable failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, param);
}

NodeHandle MSNewVariableScalarInt32(ResMgrHandle res_mgr, GraphHandle graph, int value) {
  if (res_mgr == nullptr || graph == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
    return nullptr;
  }
  ParameterPtr param = nullptr;
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    param = GetScalarParam<int>(res_fg, value, mindspore::kNumberTypeInt32);
    MS_EXCEPTION_IF_NULL(param);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New Scalar Variable failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, param);
}

NodeHandle MSNewVariableArray(ResMgrHandle res_mgr, GraphHandle graph, void *data, DataTypeC type,
                              const int64_t shape[], size_t shape_size, size_t data_len) {
  if (res_mgr == nullptr || graph == nullptr || data == nullptr || shape == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [data] or [shape] is nullptr.";
    return nullptr;
  }
  ParameterPtr param = nullptr;
  ShapeVector shape_vec(shape, shape + shape_size);
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    param = res_fg->add_parameter();
    auto tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data, data_len);
    tensor->set_param_info(std::make_shared<mindspore::ParamInfo>());
    param->set_abstract(tensor->ToAbstract());
    param->set_default_param(tensor);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New Tensor Variable failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, param);
}
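
// Usage sketch (illustrative; creates a trainable 2x2 variable from host data,
// reusing the hypothetical MS_FLOAT32 enumerator from the earlier sketch):
//   float buf[] = {1.0f, 2.0f, 3.0f, 4.0f};
//   int64_t shp[] = {2, 2};
//   NodeHandle w = MSNewVariableArray(res_mgr, graph, buf, MS_FLOAT32, shp, 2, sizeof(buf));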

NodeHandle MSNewVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, ConstTensorHandle tensor) {
  if (res_mgr == nullptr || graph == nullptr || tensor == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [tensor] is nullptr.";
    return nullptr;
  }
  ParameterPtr param = nullptr;
  try {
    auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
    MS_EXCEPTION_IF_NULL(res_fg);
    auto tensor_impl = GetSrcPtr<TensorPtr>(res_mgr, tensor);
    MS_EXCEPTION_IF_NULL(tensor_impl);
    param = res_fg->add_parameter();
    param->set_abstract(tensor_impl->ToAbstract());
    param->set_default_param(tensor_impl);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New Tensor Variable failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, param);
}

size_t MSVariableArrayGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  try {
    auto node_impl = GetSrcPtr<ParameterPtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->default_param();
    MS_EXCEPTION_IF_NULL(val);
    auto tensor = val->cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    size_t data_size = tensor->Size();
    *error = RET_OK;
    return data_size;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Tensor Variable get data size failed. Error info: " << e.what();
    *error = RET_ERROR;
    return 0;
  }
}

void *MSVariableArrayGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    return nullptr;
  }
  try {
    auto node_impl = GetSrcPtr<ParameterPtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->default_param();
    MS_EXCEPTION_IF_NULL(val);
    auto tensor = val->cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    void *data = tensor->data_c();
    return data;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Tensor Variable get data failed. Error info: " << e.what();
    return nullptr;
  }
}
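
// Usage sketch (illustrative; reads the variable's backing buffer created above):
//   STATUS err;
//   size_t bytes = MSVariableArrayGetDataSize(res_mgr, w, &err);
//   float *ptr = (err == RET_OK) ? (float *)MSVariableArrayGetData(res_mgr, w) : NULL;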

NodeHandle MSNewConstantArray(ResMgrHandle res_mgr, void *data, DataTypeC type, const int64_t shape[],
                              size_t shape_size, size_t data_len) {
  if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [data] or [shape] is nullptr.";
    return nullptr;
  }
  ShapeVector shape_vec(shape, shape + shape_size);
  ValueNodePtr value_node = nullptr;
  try {
    auto tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data, data_len);
    tensor->set_param_info(std::make_shared<mindspore::ParamInfo>());
    value_node = mindspore::NewValueNode(tensor);
    value_node->set_abstract(tensor->ToAbstract());
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New Tensor Constant failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, value_node);
}

NodeHandle MSNewConstantFromTensor(ResMgrHandle res_mgr, TensorHandle tensor) {
  if (res_mgr == nullptr || tensor == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
    return nullptr;
  }
  ValueNodePtr value_node = nullptr;
  try {
    auto tensor_impl = GetSrcPtr<TensorPtr>(res_mgr, tensor);
    MS_EXCEPTION_IF_NULL(tensor_impl);
    value_node = mindspore::NewValueNode(tensor_impl);
    value_node->set_abstract(tensor_impl->ToAbstract());
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "New Tensor Constant failed. Error info: " << e.what();
    return nullptr;
  }
  return GetRawPtr(res_mgr, value_node);
}

size_t MSConstantArrayGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    auto tensor = val->cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    size_t data_size = tensor->Size();
    *error = RET_OK;
    return data_size;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Tensor Constant get data size failed. Error info: " << e.what();
    *error = RET_ERROR;
    return 0;
  }
}

void *MSConstantArrayGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    return nullptr;
  }
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    auto tensor = val->cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    void *data = tensor->data_c();
    return data;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Tensor Constant get data failed. Error info: " << e.what();
    return nullptr;
  }
}

NodeHandle MSNewConstantScalarFloat32(ResMgrHandle res_mgr, float value) {
  MS_LOG(INFO) << "New Float32 Scalar Value!";
  if (res_mgr == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
    return nullptr;
  }
  auto value_node = mindspore::NewValueNode(value);
  value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
  return GetRawPtr(res_mgr, value_node);
}

NodeHandle MSNewConstantScalarBool(ResMgrHandle res_mgr, bool value) {
  MS_LOG(INFO) << "New Bool Scalar Value!";
  if (res_mgr == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
    return nullptr;
  }
  auto value_node = mindspore::NewValueNode(value);
  value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
  return GetRawPtr(res_mgr, value_node);
}

NodeHandle MSNewConstantScalarInt32(ResMgrHandle res_mgr, int value) {
  MS_LOG(INFO) << "New Int32 Scalar Value!";
  if (res_mgr == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
    return nullptr;
  }
  auto value_node = mindspore::NewValueNode(value);
  value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
  return GetRawPtr(res_mgr, value_node);
}

NodeHandle MSNewConstantScalarInt64(ResMgrHandle res_mgr, int64_t value) {
  MS_LOG(INFO) << "New Int64 Scalar Value!";
  if (res_mgr == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
    return nullptr;
  }
  auto value_node = mindspore::NewValueNode(value);
  value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
  return GetRawPtr(res_mgr, value_node);
}

NodeHandle MSNewConstantString(ResMgrHandle res_mgr, const char *str) {
  MS_LOG(INFO) << "New String Scalar Value!";
  if (res_mgr == nullptr || str == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [str] is nullptr.";
    return nullptr;
  }
  string str_val(str);
  auto value_node = mindspore::NewValueNode(str_val);
  value_node->set_abstract(std::make_shared<AbstractScalarImpl>(str_val));
  return GetRawPtr(res_mgr, value_node);
}

NodeHandle MSNewConstantTupleInt64(ResMgrHandle res_mgr, const int64_t vec[], size_t size) {
  MS_LOG(INFO) << "New Vector Value!";
  if (res_mgr == nullptr || vec == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [vec] is nullptr.";
    return nullptr;
  }
  auto value_node = mindspore::NewValueNode(std::vector<int64_t>(vec, vec + size));
  mindspore::AbstractBasePtrList abs_list = {};
  for (size_t i = 0; i < size; i++) {
    AbstractBasePtr base = std::make_shared<AbstractScalarImpl>(vec[i]);
    abs_list.push_back(base);
  }
  auto abstract = std::make_shared<AbstractTupleImpl>(abs_list);
  value_node->set_abstract(abstract);
  return GetRawPtr(res_mgr, value_node);
}
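
// Usage sketch (illustrative; e.g. a constant axis tuple for a reduce/transpose op):
//   int64_t perm[] = {1, 0};
//   NodeHandle perm_node = MSNewConstantTupleInt64(res_mgr, perm, 2);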

NodeHandle MSNewConstantType(ResMgrHandle res_mgr, DataTypeC type) {
  MS_LOG(INFO) << "New Type Value: " << type;
  if (res_mgr == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
    return nullptr;
  }
  auto type_ptr = mindspore::TypeIdToType(mindspore::TypeId(type));
  auto value_node = mindspore::NewValueNode(type_ptr);
  auto abstract = std::make_shared<AbstractTypeImpl>(type_ptr);
  value_node->set_abstract(abstract);
  return GetRawPtr(res_mgr, value_node);
}

int MSConstantScalarGetValueInt32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  MS_LOG(INFO) << "Get Int32 Scalar Value!";
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  int ret_val = 0;
  *error = RET_OK;
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    if (val->isa<TensorImpl>()) {
      auto val_tensor = val->cast<TensorPtr>();
      auto data = val_tensor->data_c();
      MS_EXCEPTION_IF_NULL(data);
      ret_val = static_cast<int *>(data)[0];
    } else if (val->isa<Int32ImmImpl>()) {
      auto val_imm = val->cast<Int32ImmPtr>();
      ret_val = val_imm->value();
    } else {
      MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
      *error = RET_ERROR;
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Int32 Scalar value failed. Error info: " << e.what();
    *error = RET_ERROR;
  }
  return ret_val;
}

float MSConstantScalarGetValueFloat32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  MS_LOG(INFO) << "Get Float32 Scalar Value!";
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  float ret_val = 0;
  *error = RET_OK;
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    if (val->isa<TensorImpl>()) {
      auto val_tensor = val->cast<TensorPtr>();
      auto data = val_tensor->data_c();
      MS_EXCEPTION_IF_NULL(data);
      ret_val = static_cast<float *>(data)[0];
    } else if (val->isa<Float32ImmImpl>()) {
      auto val_imm = val->cast<Float32ImmPtr>();
      ret_val = val_imm->value();
    } else {
      MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
      *error = RET_ERROR;
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Float32 Scalar value failed. Error info: " << e.what();
    *error = RET_ERROR;
  }
  return ret_val;
}

bool MSConstantScalarGetValueBool(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  MS_LOG(INFO) << "Get Bool Scalar Value!";
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return false;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return false;
  }
  bool ret_val = false;
  *error = RET_OK;
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    if (val->isa<TensorImpl>()) {
      auto val_tensor = val->cast<TensorPtr>();
      auto data = val_tensor->data_c();
      MS_EXCEPTION_IF_NULL(data);
      ret_val = static_cast<bool *>(data)[0];
    } else if (val->isa<BoolImmImpl>()) {
      auto val_imm = val->cast<BoolImmPtr>();
      ret_val = val_imm->value();
    } else {
      MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
      *error = RET_ERROR;
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Bool Scalar value failed. Error info: " << e.what();
    *error = RET_ERROR;
  }
  return ret_val;
}

int64_t MSConstantScalarGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  MS_LOG(INFO) << "Get Int64 Scalar Value!";
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  int64_t ret_val = 0;
  *error = RET_OK;
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    if (val->isa<TensorImpl>()) {
      auto val_tensor = val->cast<TensorPtr>();
      auto data = val_tensor->data_c();
      MS_EXCEPTION_IF_NULL(data);
      ret_val = static_cast<int64_t *>(data)[0];
    } else if (val->isa<Int64ImmImpl>()) {
      auto val_imm = val->cast<Int64ImmPtr>();
      ret_val = val_imm->value();
    } else {
      MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
      *error = RET_ERROR;
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Int64 Scalar value failed. Error info: " << e.what();
    *error = RET_ERROR;
  }
  return ret_val;
}

STATUS MSConstantStringGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
  MS_LOG(INFO) << "Get String Constant Value!";
  if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
    return RET_NULL_PTR;
  }
  if (str_len == 0) {
    MS_LOG(ERROR) << "Input buffer length [str_len] must be greater than 0.";
    return RET_ERROR;
  }
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    auto val_str = val->cast<StringImmPtr>();
    MS_EXCEPTION_IF_NULL(val_str);
    std::string ret_val = val_str->value();
    size_t valid_size = ret_val.size() < str_len - 1 ? ret_val.size() : str_len - 1;
    for (size_t i = 0; i < valid_size; i++) {
      str_buf[i] = ret_val.c_str()[i];
    }
    str_buf[valid_size] = '\0';
    return RET_OK;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get String Constant value failed. Error info: " << e.what();
    return RET_ERROR;
  }
}

size_t MSConstantTupleGetSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  MS_LOG(INFO) << "Get Tuple Constant size!";
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    auto val_tuple = val->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(val_tuple);
    auto tuple_size = val_tuple->size();
    *error = RET_OK;
    return tuple_size;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Tuple Constant size failed. Error info: " << e.what();
    *error = RET_ERROR;
    return 0;
  }
}

STATUS MSConstantTupleGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, int64_t vec[], size_t size) {
  MS_LOG(INFO) << "Get Tuple Constant Value!";
  if (res_mgr == nullptr || node == nullptr || vec == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [vec] is nullptr.";
    return RET_NULL_PTR;
  }
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    auto val_tuple = val->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(val_tuple);
    auto val_list = val_tuple->value();
    if (val_list.size() != size) {
      MS_LOG(ERROR) << "Invalid input vector length, it should be: " << val_list.size() << ", but got: " << size;
      return RET_ERROR;
    }
    for (size_t i = 0; i < size; i++) {
      auto val_imm = val_list[i]->cast<Int64ImmPtr>();
      MS_EXCEPTION_IF_NULL(val_imm);
      vec[i] = val_imm->value();
    }
    return RET_OK;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Tuple Constant value failed. Error info: " << e.what();
    return RET_ERROR;
  }
}
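
// Usage sketch (illustrative; reads back the tuple constant created earlier):
//   STATUS err;
//   size_t n = MSConstantTupleGetSize(res_mgr, perm_node, &err);
//   int64_t vals[8];  // assumes n <= 8 for this sketch
//   if (err == RET_OK && n <= 8) (void)MSConstantTupleGetValueInt64(res_mgr, perm_node, vals, n);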

DataTypeC MSConstantTypeGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
  MS_LOG(INFO) << "Get Type Constant Value!";
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return MS_INVALID_TYPE;
  }
  if (res_mgr == nullptr || node == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
    *error = RET_NULL_PTR;
    return MS_INVALID_TYPE;
  }
  try {
    auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
    MS_EXCEPTION_IF_NULL(node_impl);
    auto val = node_impl->value();
    MS_EXCEPTION_IF_NULL(val);
    auto val_type = val->cast<TypePtr>();
    auto ret_val = static_cast<DataTypeC>(val_type->type_id());
    *error = RET_OK;
    return ret_val;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Type Constant value failed. Error info: " << e.what();
    *error = RET_ERROR;
    return MS_INVALID_TYPE;
  }
}

PrimitivePtr GetOpPrim(ResMgrHandle res_mgr, ConstNodeHandle node) {
  auto src_node = GetSrcPtr<CNodePtr>(res_mgr, node);
  if (src_node == nullptr) {
    MS_LOG(ERROR) << "Get CNode from node handle failed.";
    return nullptr;
  }
  auto node_input = src_node->input(0);
  if (node_input == nullptr) {
    MS_LOG(ERROR) << "The node's input is nullptr.";
    return nullptr;
  }
  auto prim_node = node_input->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    MS_LOG(ERROR) << "The node's input has an invalid type.";
    return nullptr;
  }
  auto node_value = prim_node->value();
  if (node_value == nullptr) {
    MS_LOG(ERROR) << "The node's value is nullptr.";
    return nullptr;
  }
  auto prim = node_value->cast<PrimitivePtr>();
  if (prim == nullptr) {
    MS_LOG(ERROR) << "The node's value has an invalid type.";
    return nullptr;
  }
  return prim;
}
1287
STATUS MSOpSetAttrScalarFloat32(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, float value) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }
  prim->set_attr(attr_name, mindspore::MakeValue(value));
  return RET_OK;
}

STATUS MSOpSetAttrScalarBool(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, bool value) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }
  prim->set_attr(attr_name, mindspore::MakeValue(value));
  return RET_OK;
}

STATUS MSOpSetAttrScalarInt32(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int32_t value) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }
  prim->set_attr(attr_name, mindspore::MakeValue(value));
  return RET_OK;
}

STATUS MSOpSetAttrScalarInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t value) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }
  prim->set_attr(attr_name, mindspore::MakeValue(value));
  return RET_OK;
}

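// Note: DataTypeC values correspond numerically to mindspore::TypeId (the
// conversions in this file cast directly between the two), so a dtype
// attribute is stored as the mapped Type object rather than a raw integer.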
STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }
  auto cxx_type = mindspore::TypeId(value);
  prim->set_attr(attr_name, mindspore::TypeIdToType(cxx_type));
  return RET_OK;
}

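// Illustrative call for the array variant below (attribute name hypothetical):
//
//   DataTypeC dtypes[2] = {MS_FLOAT32, MS_INT32};
//   (void)MSOpSetAttrTypeArray(res_mgr, op, "out_types", dtypes, 2);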
STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value[],
                            size_t vec_size) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }
  std::vector<mindspore::ValuePtr> vec_value;
  for (size_t i = 0; i < vec_size; i++) {
    auto cxx_type = mindspore::TypeId(value[i]);
    vec_value.push_back(mindspore::TypeIdToType(cxx_type));
  }
  prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
  return RET_OK;
}

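// MSOpSetAttrArray takes a type-erased buffer: the caller declares the element
// type through data_type and the buffer is reinterpreted accordingly. A
// minimal sketch (attribute name and values are hypothetical):
//
//   int64_t strides[2] = {1, 1};
//   (void)MSOpSetAttrArray(res_mgr, op, "strides", strides, 2, MS_INT64);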
STATUS MSOpSetAttrArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, void *value, size_t vec_size,
                        DataTypeC data_type) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr || value == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] or [value] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }

  switch (data_type) {
    case MS_BOOL: {
      std::vector<bool> vec_value(static_cast<bool *>(value), static_cast<bool *>(value) + vec_size);
      prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
      break;
    }
    case MS_INT32: {
      std::vector<int32_t> vec_value(static_cast<int32_t *>(value), static_cast<int32_t *>(value) + vec_size);
      prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
      break;
    }
    case MS_INT64: {
      std::vector<int64_t> vec_value(static_cast<int64_t *>(value), static_cast<int64_t *>(value) + vec_size);
      prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
      break;
    }
    case MS_FLOAT32: {
      std::vector<float> vec_value(static_cast<float *>(value), static_cast<float *>(value) + vec_size);
      prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
      break;
    }
    default:
      MS_LOG(ERROR) << "Unsupported data type. DataTypeC ID: " << data_type << ", attribute name: " << attr_name;
      return RET_ERROR;
  }
  return RET_OK;
}

STATUS MSOpSetAttrStringArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, const char *value[],
                              size_t vec_size) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr || value == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] or [value] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }

  std::vector<mindspore::ValuePtr> vec_value;
  for (size_t i = 0; i < vec_size; i++) {
    vec_value.push_back(mindspore::MakeValue(value[i]));
  }
  prim->set_attr(attr_name, std::make_shared<mindspore::ValueList>(vec_value));
  return RET_OK;
}

STATUS MSOpSetAttrString(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, const char *value) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr || value == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] or [value] is nullptr.";
    return RET_NULL_PTR;
  }
  auto prim = GetOpPrim(res_mgr, op);
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Get primitive from node failed.";
    return RET_NULL_PTR;
  }
  std::string value_str(value);
  prim->set_attr(attr_name, mindspore::MakeValue(value_str));
  return RET_OK;
}

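// Getter sketch (illustrative): because the return value alone cannot signal
// failure, attribute getters report status through an explicit out-parameter.
//
//   STATUS err;
//   int64_t axis = MSOpGetAttrScalarInt64(res_mgr, op, "axis", &err);
//   if (err != RET_OK) { /* attribute missing or not an Int64 scalar */ }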
int64_t MSOpGetAttrScalarInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, STATUS *error) {
  if (error == nullptr) {
    MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
    return 0;
  }
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    *error = RET_NULL_PTR;
    return 0;
  }
  std::string attr_name_str(attr_name);
  try {
    auto prim = GetOpPrim(res_mgr, op);
    MS_EXCEPTION_IF_NULL(prim);
    auto value = prim->GetAttr(attr_name_str);
    MS_EXCEPTION_IF_NULL(value);
    auto value_int64 = value->cast<Int64ImmPtr>();
    MS_EXCEPTION_IF_NULL(value_int64);
    auto ret_val = value_int64->value();
    *error = RET_OK;
    return ret_val;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Attribute failed. Error info: " << e.what();
    *error = RET_ERROR;
    return 0;
  }
}

STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, int64_t values[],
                             size_t value_num) {
  if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
    return RET_NULL_PTR;
  }
  std::string attr_name_str(attr_name);
  try {
    auto prim = GetOpPrim(res_mgr, op);
    MS_EXCEPTION_IF_NULL(prim);
    auto value = prim->GetAttr(attr_name_str);
    MS_EXCEPTION_IF_NULL(value);
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    auto value_list = value_tuple->value();
    if (value_list.size() != value_num) {
      MS_LOG(ERROR) << "Invalid value_num, the attribute array length is: " << value_list.size()
                    << ", but got: " << value_num;
      return RET_ERROR;
    }
    for (size_t i = 0; i < value_num; i++) {
      auto val_imm = value_list[i]->cast<Int64ImmPtr>();
      MS_EXCEPTION_IF_NULL(val_imm);
      values[i] = val_imm->value();
    }
    return RET_OK;
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Get Attribute failed. Error info: " << e.what();
    return RET_ERROR;
  }
}

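// Note: MSOpGetAttrArrayInt64 requires value_num to match the stored array
// length exactly; a mismatched caller buffer is rejected with RET_ERROR rather
// than truncated or partially filled.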
STATUS MSOpSetName(ResMgrHandle res_mgr, NodeHandle node, const char *name) {
  if (res_mgr == nullptr || node == nullptr || name == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [name] is nullptr.";
    return RET_NULL_PTR;
  }
  auto node_impl = GetSrcPtr<CNodePtr>(res_mgr, node);
  if (node_impl == nullptr) {
    MS_LOG(ERROR) << "Get source pointer failed. Please check whether the input node is an operator node.";
    return RET_ERROR;
  }
  node_impl->set_fullname_with_scope(name);
  return RET_OK;
}

STATUS MSNodeGetName(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
  if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
    MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
    return RET_NULL_PTR;
  }
  if (str_len == 0) {
    MS_LOG(ERROR) << "Input [str_len] must be non-zero.";
    return RET_ERROR;
  }
  auto node_impl = GetSrcPtr<AnfNodePtr>(res_mgr, node);
  if (node_impl == nullptr) {
    MS_LOG(ERROR) << "Get source pointer failed.";
    return RET_ERROR;
  }
  auto name = node_impl->fullname_with_scope();
  size_t valid_size = name.size() < str_len - 1 ? name.size() : str_len - 1;
  for (size_t i = 0; i < valid_size; i++) {
    str_buf[i] = name[i];
  }
  str_buf[valid_size] = '\0';
  return RET_OK;
}

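// Usage sketch (illustrative): names are copied with truncation to
// str_len - 1 characters and the buffer is always NUL-terminated.
//
//   char name_buf[128];
//   if (MSNodeGetName(res_mgr, node, name_buf, sizeof(name_buf)) == RET_OK) {
//     printf("node: %s\n", name_buf);
//   }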
// dynamic op / eager mode
std::shared_ptr<InnerOpInfo> GenerateInnerInfo(ResMgrHandle res_mgr, const char *op_type, TensorHandle const inputs[],
                                               size_t input_num, size_t output_num, const DynamicOpInfo &extra_info) {
  MS_EXCEPTION_IF_NULL(op_type);
  MS_EXCEPTION_IF_NULL(inputs);
  std::vector<ValuePtr> src_inputs{};
  std::vector<ShapeVector> out_shapes{};
  std::vector<DataTypeC> out_dtypes{};
  std::vector<std::pair<std::string, ValuePtr>> attrs_pair{};
  for (size_t i = 0; i < input_num; i++) {
    auto input = GetSrcPtr<ValuePtr>(res_mgr, inputs[i]);
    if (input == nullptr) {
      MS_LOG(EXCEPTION) << "Invalid input. Index: " << i;
    }
    (void)src_inputs.emplace_back(input);
  }
  if (extra_info.output_shapes != nullptr && extra_info.output_dtypes != nullptr) {
    for (size_t i = 0; i < output_num; i++) {
      MS_EXCEPTION_IF_NULL(extra_info.output_dims);
      size_t dim = extra_info.output_dims[i];
      ShapeVector out_shape{};
      MS_EXCEPTION_IF_NULL(extra_info.output_shapes[i]);
      for (size_t j = 0; j < dim; j++) {
        (void)out_shape.emplace_back(extra_info.output_shapes[i][j]);
      }
      (void)out_shapes.emplace_back(out_shape);
      (void)out_dtypes.emplace_back(extra_info.output_dtypes[i]);
    }
  }
  for (size_t i = 0; i < extra_info.attr_num; i++) {
    MS_EXCEPTION_IF_NULL(extra_info.attr_names[i]);
    auto value = GetSrcPtr<ValuePtr>(res_mgr, extra_info.attr_values[i]);
    if (value == nullptr) {
      MS_LOG(EXCEPTION) << "Get attribute's source pointer failed, attribute index: " << i;
    }
    (void)attrs_pair.emplace_back(std::make_pair(extra_info.attr_names[i], value));
  }
  return std::make_shared<InnerOpInfo>(op_type, src_inputs, out_shapes, out_dtypes, attrs_pair);
}

STATUS CheckExtraInfo(const DynamicOpInfo &extra_info) {
  MS_ERROR_IF_TRUE_W_RET_N_LOG(
    extra_info.attr_num == 0 && (extra_info.attr_names != nullptr || extra_info.attr_values != nullptr), RET_ERROR,
    "The attr_names and attr_values must be nullptr if attr_num is 0!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(
    extra_info.attr_num != 0 && (extra_info.attr_names == nullptr || extra_info.attr_values == nullptr), RET_ERROR,
    "The attr_names and attr_values must be specified if attr_num is non-zero!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(extra_info.output_dims != nullptr && extra_info.output_shapes == nullptr, RET_ERROR,
                               "The output_shapes must not be nullptr if output_dims is specified!");
  return RET_OK;
}

STATUS OpRunInfoSetInputs(ResMgrHandle res_mgr, TensorHandle const inputs[], size_t input_num,
                          FrontendOpRunInfoPtr op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto prim = op_run_info->op_grad_info->op_prim;
  MS_EXCEPTION_IF_NULL(prim);
  op_run_info->input_size = input_num;
  op_run_info->op_grad_info->input_value.resize(input_num);
  for (size_t i = 0; i < input_num; i++) {
    auto in_arg = GetSrcPtr<ValuePtr>(res_mgr, inputs[i]);
    if (in_arg == nullptr) {
      MS_LOG(ERROR) << "Invalid input. Index: " << i;
      return RET_ERROR;
    }
    op_run_info->op_grad_info->input_value[i] = in_arg;
  }
  return RET_OK;
}

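// DynamicOpInfer prefers caller-supplied output shapes/dtypes when they are
// fully specified; otherwise it falls back to the built-in shape/type
// inference over the input abstracts.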
STATUS DynamicOpInfer(size_t output_num, FrontendOpRunInfoPtr op_run_info, const DynamicOpInfo &extra_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  // build input abstracts
  op_run_info->op_grad_info->input_abs.resize(op_run_info->input_size);
  for (size_t i = 0; i < op_run_info->input_size; ++i) {
    auto input_value = op_run_info->op_grad_info->input_value[i];
    op_run_info->op_grad_info->input_abs[i] = input_value->ToAbstract();
  }
  // do infer
  AbstractBasePtr out_abs = nullptr;
  auto prim = op_run_info->op_grad_info->op_prim;
  MS_EXCEPTION_IF_NULL(prim);
  if (extra_info.output_shapes != nullptr && extra_info.output_dims != nullptr && extra_info.output_dtypes != nullptr) {
    auto shape = BuildShape(extra_info.output_shapes, extra_info.output_dims, output_num);
    auto type = BuildType(extra_info.output_dtypes, output_num);
    out_abs = BuildAbstract(shape, type);
  } else {
    MS_LOG(INFO) << "Output shapes and dtypes are not fully specified, falling back to inner infer.";
    prim->BeginRecordAddAttr();
    out_abs = OpInferShapeAndType(prim, op_run_info->op_grad_info->input_abs);
    prim->EndRecordAddAttr();
  }
  MS_EXCEPTION_IF_NULL(out_abs);
  op_run_info->base_op_run_info.abstract = out_abs;
  return RET_OK;
}

MindRTBackendPtr DynamicOpGetMindRTBackend(ResMgrHandle res_mgr, const string &cur_device_target, uint32_t device_id) {
  auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
  auto cached_backend = res_mgr_ptr->GetBackendFromCache(cur_device_target);
  if (cached_backend != nullptr) {
    return cached_backend;
  }
  std::lock_guard<std::mutex> guard(mindspore::pipeline::Resource::GetBackendInitMutex());
  auto backend = std::make_shared<mindspore::compile::MindRTBackend>("ms", cur_device_target, device_id);
  MS_EXCEPTION_IF_NULL(backend);
  res_mgr_ptr->CacheBackend(cur_device_target, backend);
  return backend;
}

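// Backends are cached per device target in the ResourceManager, so repeated
// eager launches skip backend construction; creation itself is serialized by
// the pipeline's backend-init mutex above.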
ValuePtr DynamicOpRun(ResMgrHandle res_mgr, const FrontendOpRunInfoPtr &op_run_info) {
  MS_LOG(DEBUG) << "DynamicOpRun start";
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto ms_context = mindspore::MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_id = ms_context->get_param<uint32_t>(mindspore::MS_CTX_DEVICE_ID);
  ms_context->set_param<bool>(mindspore::MS_CTX_ENABLE_PYNATIVE_INFER, true);
  mindspore::pynative::PyNativeAlgo::DataConvert::GetInputTensor(op_run_info, nullptr);
  auto backend_op_run_info = std::make_shared<mindspore::BackendOpRunInfo>(
    op_run_info->base_op_run_info, std::make_shared<mindspore::Primitive>(*op_run_info->op_grad_info->op_prim), true,
    false);

  mindspore::VectorRef outputs;
  const auto &cur_mindrt_backend =
    DynamicOpGetMindRTBackend(res_mgr, op_run_info->base_op_run_info.device_target, device_id);
  MS_EXCEPTION_IF_NULL(cur_mindrt_backend);
  py::scoped_interpreter py_scope;
  if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
    mindspore::AnfAlgo::SetDynamicAttrToPrim(backend_op_run_info->op_prim);
    cur_mindrt_backend->RunOpDynamic(backend_op_run_info, &outputs);
  } else {
    cur_mindrt_backend->RunOp(backend_op_run_info, &outputs);
  }

  if (op_run_info->base_op_run_info.has_dynamic_output) {
    op_run_info->base_op_run_info.abstract = backend_op_run_info->base_op_run_info.abstract;
  }
  bool is_out_sequence = (op_run_info->base_op_run_info.abstract == nullptr ||
                          op_run_info->base_op_run_info.abstract->isa<mindspore::abstract::AbstractSequence>());
  const auto &result = mindspore::pynative::PyNativeAlgo::DataConvert::VectorRefToValue(
    outputs, op_run_info->requires_grad, is_out_sequence);
  ms_context->set_param<bool>(mindspore::MS_CTX_ENABLE_PYNATIVE_INFER, false);
  MS_LOG(DEBUG) << "DynamicOpRun end";
  return result;
}

STATUS MSRunOpWithInfo(ResMgrHandle res_mgr, const char *op_type, TensorHandle const inputs[], size_t input_num,
                       TensorHandle outputs[], size_t output_num, DynamicOpInfo extra_info) {
  MS_ERROR_IF_TRUE_W_RET_N_LOG(res_mgr == nullptr, RET_NULL_PTR, "Input Handle [res_mgr] is nullptr!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(inputs == nullptr, RET_NULL_PTR, "Input Handle [inputs] is nullptr!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(outputs == nullptr, RET_NULL_PTR, "Input Handle [outputs] is nullptr!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num == 0, RET_NULL_PTR, "Input [input_num] must be non-zero!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(output_num == 0, RET_NULL_PTR, "Input [output_num] must be non-zero!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(CheckExtraInfo(extra_info) != RET_OK, RET_NULL_PTR, "Input [extra_info] is invalid!");
  try {
    auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
    FrontendOpRunInfoPtr op_run_info = nullptr;
    auto op_info = GenerateInnerInfo(res_mgr, op_type, inputs, input_num, output_num, extra_info);
    auto cached_run_info = res_mgr_ptr->GetOpRunInfoFromCache(op_info);
    if (cached_run_info != nullptr) {
      op_run_info = cached_run_info;
      // set inputs
      auto ret = OpRunInfoSetInputs(res_mgr, inputs, input_num, op_run_info);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Dynamic Op set inputs failed.";
        return RET_ERROR;
      }
    } else {
      // create op_run_info
      op_run_info = std::make_shared<mindspore::pynative::FrontendOpRunInfo>();
      op_run_info->base_op_run_info.op_name = op_type;
      op_run_info->requires_grad = false;
      auto ms_context = mindspore::MsContext::GetInstance();
      auto cur_target = ms_context->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
      op_run_info->base_op_run_info.device_target = cur_target;
      // create prim
      auto prim = std::make_shared<PrimitiveImpl>(op_type);
      op_run_info->op_grad_info->op_prim = prim;
      // set inputs
      bool is_dynamic_shape =
        op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.use_dynamic_shape_process;
      mindspore::pynative::PyNativeAlgo::Common::GetConstInputToAttr(prim, op_type, cur_target, is_dynamic_shape,
                                                                     &op_run_info->input_to_attr);
      auto ret = OpRunInfoSetInputs(res_mgr, inputs, input_num, op_run_info);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Dynamic Op set inputs failed.";
        return RET_ERROR;
      }
      // set attrs
      if (extra_info.attr_names != nullptr && extra_info.attr_values != nullptr) {
        ret = OpSetAttrs(res_mgr, prim, extra_info.attr_names, extra_info.attr_values, extra_info.attr_num);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Dynamic Op set attributes failed.";
          return RET_ERROR;
        }
      }
      // infer and set abstract
      ret = DynamicOpInfer(output_num, op_run_info, extra_info);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Dynamic Op infer shape and type failed.";
        return RET_ERROR;
      }
      // cache op run info
      res_mgr_ptr->CacheOpRunInfo(op_info, op_run_info);
    }

    // run op
    op_run_info->real_out = DynamicOpRun(res_mgr, op_run_info);
    if (op_run_info->real_out->isa<ValueSequenceImpl>()) {
      const auto &result_v_list = op_run_info->real_out->cast<ValueSequencePtr>();
      if (result_v_list->size() == 1 && op_run_info->base_op_run_info.abstract != nullptr &&
          !op_run_info->base_op_run_info.abstract->isa<mindspore::abstract::AbstractSequence>()) {
        op_run_info->real_out = result_v_list->value().front();
      }
    }

    // clear used input tensors
    op_run_info->base_op_run_info.expanded_input_values.clear();
    op_run_info->base_op_run_info.input_types.clear();

    // get output tensors
    const std::vector<TensorPtr> &ref_outputs = ConvertOutputToTensor(op_run_info->real_out);
    if (ref_outputs.size() != output_num) {
      MS_LOG(ERROR) << "Invalid output number, it should be: " << ref_outputs.size() << ", but got: " << output_num;
      return RET_ERROR;
    }
    for (size_t i = 0; i < output_num; i++) {
      outputs[i] = GetRawPtr(res_mgr, ref_outputs[i]);
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Run op failed. Error info: " << e.what();
    return RET_ERROR;
  }
  return RET_OK;
}

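// Minimal eager-execution sketch (illustrative; the tensor handles are assumed
// to come from the tensor-creation C API, and "Add" is an example op type):
//
//   TensorHandle ins[2] = {a, b};
//   TensorHandle outs[1];
//   if (MSRunOp(res_mgr, "Add", ins, 2, outs, 1) != RET_OK) { /* error */ }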
STATUS MSRunOp(ResMgrHandle res_mgr, const char *op_type, TensorHandle const inputs[], size_t input_num,
               TensorHandle outputs[], size_t output_num) {
  MS_ERROR_IF_TRUE_W_RET_N_LOG(res_mgr == nullptr, RET_NULL_PTR, "Input Handle [res_mgr] is nullptr!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(inputs == nullptr, RET_NULL_PTR, "Input Handle [inputs] is nullptr!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(outputs == nullptr, RET_NULL_PTR, "Input Handle [outputs] is nullptr!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num == 0, RET_NULL_PTR, "Input [input_num] must be non-zero!");
  MS_ERROR_IF_TRUE_W_RET_N_LOG(output_num == 0, RET_NULL_PTR, "Input [output_num] must be non-zero!");
  DynamicOpInfo extra_info = {nullptr, nullptr, 0, nullptr, nullptr, nullptr};
  return MSRunOpWithInfo(res_mgr, op_type, inputs, input_num, outputs, output_num, extra_info);
}
