• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/c_api/ms/node.h"
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "c_api/src/helper.h"
20 #include "c_api/src/common.h"
21 #include "c_api/src/utils.h"
22 #include "base/base.h"
23 #include "ir/param_info.h"
24 #include "ir/anf.h"
25 #include "ir/scope.h"
26 #include "ir/func_graph_cloner.h"
27 #include "include/backend/optimizer/helper.h"
28 #include "kernel/oplib/oplib.h"
29 #include "kernel/oplib/opinfo.h"
30 #include "abstract/dshape.h"
31 #include "pipeline/pynative/base.h"
32 #include "pipeline/pynative/pynative_utils.h"
33 #include "mindspore/core/ops/other_ops.h"
34 
35 constexpr size_t firstInIdx = 1;
36 constexpr size_t secondInIdx = 2;
37 constexpr size_t switchInputNum = 3;
38 static const size_t maxMallocSize = GetMaxMallocSize();
MSNewOp(ResMgrHandle res_mgr,GraphHandle graph,const char * op_type,Handle const inputs[],size_t input_num,const char * const * attr_names,ValueHandle attrs[],size_t attr_num)39 NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, Handle const inputs[],
40                    size_t input_num, const char *const *attr_names, ValueHandle attrs[], size_t attr_num) {
41   if (res_mgr == nullptr || graph == nullptr || op_type == nullptr || inputs == nullptr) {
42     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [op_type] or [inputs] is nullptr.";
43     return nullptr;
44   }
45   // convert raw input pointer to source shared pointer
46   auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
47   if (res_fg == nullptr) {
48     MS_LOG(ERROR) << "Get source pointer failed.";
49     return nullptr;
50   }
51   auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
52   std::vector<AnfNodePtr> cnode_inputs{};
53   mindspore::AbstractBasePtrList abs_list{};
54   auto prim = std::make_shared<PrimitiveImpl>(op_type);
55   if (attr_names != nullptr && attrs != nullptr) {
56     auto ret = OpSetAttrs(res_mgr, prim, attr_names, attrs, attr_num);
57     if (ret != RET_OK) {
58       MS_LOG(ERROR) << "Op set attributes failed.";
59       return nullptr;
60     }
61   }
62   auto prim_node = mindspore::NewValueNode(prim);
63   cnode_inputs.push_back(prim_node);
64   CNodePtr cnode = nullptr;
65   try {
66     for (size_t i = 0; i < input_num; ++i) {
67       auto input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
68       MS_EXCEPTION_IF_NULL(input);
69       if (input->isa<ParameterImpl>() && input->func_graph() != res_fg) {
70         (void)res_fg->AddFreeVariable(input);
71       }
72       ConvertConstScalarInputToTensor(input);
73       cnode_inputs.push_back(input);
74       abs_list.push_back(input->abstract());
75     }
76     cnode = res_fg->NewCNodeInOrder(cnode_inputs);
77     MS_EXCEPTION_IF_NULL(cnode);
78     if (res_mgr_ptr->GetInfer()) {
79       auto out_abs = OpInferShapeAndType(prim, abs_list);
80       cnode->set_abstract(out_abs);
81     }
82   } catch (const std::exception &e) {
83     MS_LOG(ERROR) << "FuncGraph create new operator failed. Error info: " << e.what();
84     return nullptr;
85   }
86   MS_LOG(INFO) << "Add Operator" << op_type;
87   return GetRawPtr(res_mgr, cnode);
88 }
89 
MSPackNodesTuple(ResMgrHandle res_mgr,GraphHandle graph,Handle const nodes[],size_t node_num)90 NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, Handle const nodes[], size_t node_num) {
91   if (res_mgr == nullptr || graph == nullptr || nodes == nullptr) {
92     MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [nodes] is nullptr.";
93     return nullptr;
94   }
95   CNodePtr make_tuple_cnode = nullptr;
96   try {
97     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
98     MS_EXCEPTION_IF_NULL(res_fg);
99     std::vector<AnfNodePtr> in_nodes{NewValueNode(mindspore::prim::kPrimMakeTuple)};
100     mindspore::AbstractBasePtrList abs_list{};
101     for (size_t i = 0; i < node_num; ++i) {
102       auto in_node = GetSrcPtr<AnfNodePtr>(res_mgr, nodes[i]);
103       MS_EXCEPTION_IF_NULL(in_node);
104       in_nodes.push_back(in_node);
105       ConvertConstScalarInputToTensor(in_node);
106       abs_list.push_back(in_node->abstract());
107     }
108     make_tuple_cnode = res_fg->NewCNodeInOrder(in_nodes);
109     MS_EXCEPTION_IF_NULL(make_tuple_cnode);
110     make_tuple_cnode->set_abstract(std::make_shared<AbstractTupleImpl>(abs_list));
111   } catch (const std::exception &e) {
112     MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
113     return nullptr;
114   }
115   return GetRawPtr(res_mgr, make_tuple_cnode);
116 }
117 
MSOpGetSpecOutput(ResMgrHandle res_mgr,GraphHandle graph,ConstNodeHandle op,size_t i)118 NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op, size_t i) {
119   if (res_mgr == nullptr || graph == nullptr || op == nullptr) {
120     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
121     return nullptr;
122   }
123   CNodePtr ret_node = nullptr;
124   try {
125     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
126     MS_EXCEPTION_IF_NULL(res_fg);
127     auto cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
128     MS_EXCEPTION_IF_NULL(cnode);
129     auto abs = cnode->abstract();
130     if (abs == nullptr) {
131       MS_LOG(ERROR) << "Input op's abstract is nullptr!";
132       return nullptr;
133     }
134     if (abs->isa<mindspore::abstract::AbstractTuple>()) {
135       auto branch_num = abs->cast<mindspore::abstract::AbstractTuplePtr>()->size();
136       if (i >= branch_num) {
137         MS_LOG(ERROR) << "Invalid output branch index, it should be less than " << branch_num << ", but got: " << i;
138         return nullptr;
139       }
140       auto idx = mindspore::NewValueNode(mindspore::SizeToLong(i));
141       auto abs_scalar = std::make_shared<mindspore::abstract::AbstractScalar>(mindspore::SizeToInt(i));
142       idx->set_abstract(abs_scalar);
143       ret_node = res_fg->NewCNodeInOrder({NewValueNode(mindspore::prim::kPrimTupleGetItem), cnode, idx});
144       MS_EXCEPTION_IF_NULL(ret_node);
145       ret_node->set_abstract(abs->cast<mindspore::abstract::AbstractTuplePtr>()->elements()[i]);
146     } else {
147       if (i >= 1) {
148         MS_LOG(ERROR) << "Invalid output index. The op has only one output, so the output index should be 0, or you can"
149                          " directly use this op as the output without calling this function, but got: "
150                       << i;
151         return nullptr;
152       }
153       MS_LOG(WARNING) << "The op has only one output, you can directly use this op as the output without calling this "
154                          "function. Now the op itself is returned.";
155       ret_node = cnode;
156     }
157   } catch (const std::exception &e) {
158     MS_LOG(ERROR) << "FuncGraph get output failed. Error info: " << e.what();
159     return nullptr;
160   }
161   return GetRawPtr(res_mgr, ret_node);
162 }
163 
BuildSwitchStructure(ResMgrHandle res_mgr,GraphHandle graph,NodeHandle const switch_input[],size_t input_num,bool set_fg_out)164 CNodePtr BuildSwitchStructure(ResMgrHandle res_mgr, GraphHandle graph, NodeHandle const switch_input[],
165                               size_t input_num, bool set_fg_out) {
166   MS_EXCEPTION_IF_NULL(res_mgr);
167   MS_EXCEPTION_IF_NULL(graph);
168   MS_EXCEPTION_IF_NULL(switch_input);
169   MS_EXCEPTION_IF_CHECK_FAIL(input_num == switchInputNum, "Switch's input number must be 3!");
170   NodeHandle switch_op = MSNewOp(res_mgr, graph, "Switch", switch_input, input_num, NULL, NULL, 0);
171   if (switch_op == nullptr) {
172     MS_LOG(ERROR) << "Get Switch op failed!";
173     return nullptr;
174   }
175   auto src_switch = GetSrcPtr<CNodePtr>(res_mgr, switch_op);
176   MS_EXCEPTION_IF_NULL(src_switch);
177   auto fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
178   MS_EXCEPTION_IF_NULL(fg);
179   CNodePtr switch_call = fg->NewCNodeInOrder({src_switch});
180   MS_EXCEPTION_IF_NULL(switch_call);
181   if (set_fg_out) {
182     fg->set_output(switch_call);
183   }
184   auto first_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[firstInIdx]);
185   MS_EXCEPTION_IF_NULL(first_node);
186   auto second_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[secondInIdx]);
187   MS_EXCEPTION_IF_NULL(second_node);
188   // AddFuncGraphCNodeIndex is used to set cnode_index. A funcgraph's cnode_index is a list of pair
189   // with pair-struct is (CNODE, index). The CNODE is in another funcgraph, who uses the funcgraph as its input.
190   // for eg. if fg1's cnode A uses fg2 as A's first input, then fg2's conde_index is (A, 1)
191   if (first_node->isa<ValueNodeImpl>()) {
192     fg->AddValueNode(first_node);
193     if (mindspore::IsValueNode<FuncGraphImpl>(first_node)) {
194       auto used = mindspore::GetValueNode<FuncGraphPtr>(first_node);
195       used->AddFuncGraphCNodeIndex(
196         std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, firstInIdx + 1)));
197       (void)fg->AddFuncGraphUsed(used);
198     }
199   }
200   if (second_node->isa<ValueNodeImpl>()) {
201     fg->AddValueNode(second_node);
202     if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
203       auto used = mindspore::GetValueNode<FuncGraphPtr>(second_node);
204       used->AddFuncGraphCNodeIndex(
205         std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, secondInIdx + 1)));
206       (void)fg->AddFuncGraphUsed(used);
207     }
208   }
209   // Switch-call's abstract is equal to second branch.
210   if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
211     auto sub_fg = mindspore::GetValueNode<FuncGraphPtr>(second_node);
212     switch_call->set_abstract(sub_fg->output()->abstract());
213   }
214   return switch_call;
215 }
216 
MSNewSwitch(ResMgrHandle res_mgr,GraphHandle graph,Handle cond,ConstGraphHandle true_br,ConstGraphHandle false_br)217 NodeHandle MSNewSwitch(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, ConstGraphHandle true_br,
218                        ConstGraphHandle false_br) {
219   if (res_mgr == nullptr || graph == nullptr || cond == nullptr || true_br == nullptr || false_br == nullptr) {
220     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [true_br] or [false_br] is nullptr.";
221     return nullptr;
222   }
223   try {
224     auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
225     MS_EXCEPTION_IF_NULL(src_cond);
226     NodeHandle cond_raw_ptr = nullptr;
227     if (src_cond->isa<FuncGraphImpl>()) {
228       auto cond_graph = src_cond->cast<FuncGraphPtr>();
229       MS_EXCEPTION_IF_NULL(cond_graph);
230       auto cond_node = mindspore::NewValueNode(cond_graph);
231       cond_node->set_abstract(cond_graph->ToAbstract());
232       cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
233     } else {
234       cond_raw_ptr = cond;
235     }
236     auto true_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, true_br);
237     MS_EXCEPTION_IF_NULL(true_fg);
238     auto true_node = mindspore::NewValueNode(true_fg);
239     true_node->set_abstract(true_fg->ToAbstract());
240     NodeHandle true_raw_ptr = GetRawPtr(res_mgr, true_node);
241 
242     auto false_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, false_br);
243     MS_EXCEPTION_IF_NULL(false_fg);
244     auto false_node = mindspore::NewValueNode(false_fg);
245     false_node->set_abstract(false_fg->ToAbstract());
246     NodeHandle false_raw_ptr = GetRawPtr(res_mgr, false_node);
247 
248     NodeHandle switch_input[] = {cond_raw_ptr, true_raw_ptr, false_raw_ptr};
249     auto switch_call = BuildSwitchStructure(res_mgr, graph, switch_input, switchInputNum, false);
250     MS_EXCEPTION_IF_NULL(switch_call);
251     return GetRawPtr(res_mgr, switch_call);
252   } catch (const std::exception &e) {
253     MS_LOG(ERROR) << "New Switch node failed. Error info: " << e.what();
254     return nullptr;
255   }
256 }
257 
HandleFVInWhileGraph(const FuncGraphPtr & main_fg,const FuncGraphPtr & body_fg,const FuncGraphPtr & after_fg)258 void HandleFVInWhileGraph(const FuncGraphPtr &main_fg, const FuncGraphPtr &body_fg, const FuncGraphPtr &after_fg) {
259   std::vector<AnfNodePtr> fv_to_restore{};
260   auto body_fvs = body_fg->free_variables();
261   for (const auto &fv : body_fvs) {
262     auto fv_node = fv.first;
263     MS_EXCEPTION_IF_NULL(fv_node);
264     if (fv_node->func_graph() != main_fg &&
265         std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
266       fv_to_restore.push_back(fv_node);
267     }
268   }
269   auto after_fvs = after_fg->free_variables();
270   for (const auto &fv : after_fvs) {
271     auto fv_node = fv.first;
272     MS_EXCEPTION_IF_NULL(fv_node);
273     if (fv_node->func_graph() != main_fg &&
274         std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
275       fv_to_restore.push_back(fv_node);
276     }
277   }
278 
279   (void)mindspore::LiftingClone(main_fg);
280 
281   auto main_manager = Manage(main_fg);
282   std::vector<AnfNodePtr> new_main_params{};
283   auto main_params = main_fg->parameters();
284   for (const auto &main_param : main_params) {
285     auto src_main_param = main_param->cast<ParameterPtr>();
286     MS_EXCEPTION_IF_NULL(src_main_param);
287     auto found_in_fv_list =
288       find_if(fv_to_restore.begin(), fv_to_restore.end(), [&main_param](const AnfNodePtr &fv_param) {
289         return !main_param->ToString().empty() && main_param->ToString() == fv_param->ToString();
290       });
291     if (found_in_fv_list != fv_to_restore.end()) {
292       (void)main_manager->Replace(main_param, *found_in_fv_list);
293     } else if (src_main_param->has_default()) {
294       auto const_input = mindspore::NewValueNode(src_main_param->default_param());
295       const_input->set_abstract(src_main_param->abstract());
296       (void)main_manager->Replace(main_param, const_input);
297     } else {
298       new_main_params.push_back(main_param);
299     }
300   }
301   main_fg->set_parameters(new_main_params);
302 }
303 
MSNewWhile(ResMgrHandle res_mgr,GraphHandle graph,Handle cond,GraphHandle body_graph,GraphHandle after_graph)304 NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, GraphHandle body_graph,
305                       GraphHandle after_graph) {
306   if (res_mgr == nullptr || graph == nullptr || cond == nullptr || body_graph == nullptr || after_graph == nullptr) {
307     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [body_graph] or [after_graph] is nullptr.";
308     return nullptr;
309   }
310   try {
311     auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
312     MS_EXCEPTION_IF_NULL(src_cond);
313     NodeHandle cond_raw_ptr = nullptr;
314     GraphHandle cond_graph = nullptr;
315     FuncGraphPtr src_cond_graph = nullptr;
316     auto main_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
317     if (src_cond->isa<FuncGraphImpl>()) {
318       cond_graph = cond;
319       src_cond_graph = src_cond->cast<FuncGraphPtr>();
320       MS_EXCEPTION_IF_NULL(src_cond_graph);
321       auto cond_node = src_cond_graph->output();
322       MS_EXCEPTION_IF_NULL(cond_node);
323       cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
324     } else {
325       auto cond_fg = std::make_shared<FuncGraphImpl>();
326       MS_EXCEPTION_IF_NULL(cond_fg);
327       cond_graph = GetRawPtr(res_mgr, cond_fg);
328       MS_EXCEPTION_IF_NULL(cond_graph);
329       src_cond_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, cond_graph);
330       MS_EXCEPTION_IF_NULL(src_cond_graph);
331       (void)main_fg->AddFuncGraphUsed(src_cond_graph);
332       if (src_cond->isa<CNodeImpl>()) {
333         auto cond_node = src_cond->cast<CNodePtr>();
334         MS_EXCEPTION_IF_NULL(cond_node);
335         auto new_cond = src_cond_graph->NewCNodeInOrder(cond_node->inputs());
336         MS_EXCEPTION_IF_NULL(new_cond);
337         new_cond->set_abstract(cond_node->abstract());
338         cond_raw_ptr = GetRawPtr(res_mgr, new_cond);
339       } else {
340         cond_raw_ptr = cond;
341       }
342     }
343 
344     auto body_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, body_graph);
345     MS_EXCEPTION_IF_NULL(body_fg);
346     auto body_node = mindspore::NewValueNode(body_fg);
347     body_node->set_abstract(body_fg->ToAbstract());
348     NodeHandle body_raw_ptr = GetRawPtr(res_mgr, body_node);
349 
350     auto after_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, after_graph);
351     MS_EXCEPTION_IF_NULL(after_fg);
352     auto after_node = mindspore::NewValueNode(after_fg);
353     after_node->set_abstract(after_fg->ToAbstract());
354     NodeHandle after_raw_ptr = GetRawPtr(res_mgr, after_node);
355 
356     NodeHandle switch_input[] = {cond_raw_ptr, body_raw_ptr, after_raw_ptr};
357     (void)BuildSwitchStructure(res_mgr, cond_graph, switch_input, switchInputNum, true);
358 
359     // handle main graph call
360     NodeHandle main_func_call = MSNewFuncCallNode(res_mgr, graph, cond_graph, nullptr, 0);
361     auto src_call = GetSrcPtr<AnfNodePtr>(res_mgr, main_func_call);
362     main_fg->set_output(src_call);
363 
364     // handle free parameters in while graphs
365     HandleFVInWhileGraph(main_fg, body_fg, after_fg);
366 
367     // handle multi outputs in body graph
368     auto sub_graph_node = mindspore::NewValueNode(src_cond_graph);
369     sub_graph_node->set_abstract(src_cond_graph->ToAbstract());
370     std::vector<AnfNodePtr> sub_input_nodes{sub_graph_node};
371     auto body_out_node = body_fg->output();
372     MS_EXCEPTION_IF_NULL(body_out_node);
373     if (IsPrimitiveCNode(body_out_node, mindspore::prim::kPrimMakeTuple)) {
374       auto body_out_cnode = body_out_node->cast<CNodePtr>();
375       for (size_t i = 1; i < body_out_cnode->size(); i++) {
376         sub_input_nodes.push_back(body_out_cnode->input(i));
377       }
378     } else {
379       sub_input_nodes.push_back(body_out_node);
380     }
381     auto body_func_call = body_fg->NewCNodeInOrder(sub_input_nodes);
382     MS_EXCEPTION_IF_NULL(src_cond_graph->output());
383     MS_EXCEPTION_IF_NULL(body_func_call);
384     body_func_call->set_abstract(src_cond_graph->output()->abstract());
385     body_fg->set_output(body_func_call);
386     return main_func_call;
387   } catch (const std::exception &e) {
388     MS_LOG(ERROR) << "New While node failed. Error info: " << e.what();
389     return nullptr;
390   }
391 }
392 
CustomOpInferShape(const CustomOpInfo & info,const std::vector<AbstractBasePtr> & input_args)393 std::vector<BaseShapePtr> CustomOpInferShape(const CustomOpInfo &info, const std::vector<AbstractBasePtr> &input_args) {
394   auto dyn_arr_deleter = [](int64_t **x, size_t dims) {
395     std::for_each(x, x + dims, std::default_delete<int64_t[]>());
396     delete[] x;
397   };
398   if (info.output_shapes != nullptr) {
399     if (info.output_dims == nullptr) {
400       MS_LOG(ERROR) << "Output dims must be given if output shapes are specified!";
401       return {};
402     }
403     auto infer_shape = BuildShape(info.output_shapes, info.output_dims, info.output_num);
404     return infer_shape;
405   } else if (info.shape_infer_func != nullptr) {
406     size_t input_num = info.input_num;
407     size_t output_num = info.output_num;
408     MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num * sizeof(size_t) > maxMallocSize, {},
409                                  "The input_num is too large for memory allocation.");
410     MS_ERROR_IF_TRUE_W_RET_N_LOG(output_num * sizeof(size_t) > maxMallocSize, {},
411                                  "The output_num is too large for memory allocation.");
412     auto out_dims_arr = std::make_unique<size_t[]>(output_num);
413     std::unique_ptr<int64_t *, std::function<void(int64_t **)>> out_shapes_arr(
414       new (std::nothrow) int64_t *[output_num](), std::bind(dyn_arr_deleter, std::placeholders::_1, output_num));
415     for (size_t i = 0; i < output_num; i++) {
416       (out_shapes_arr.get())[i] = new int64_t[MAX_DIMS];
417     }
418     auto in_dims_arr = std::make_unique<size_t[]>(input_num);
419     std::unique_ptr<int64_t *, std::function<void(int64_t **)>> in_shapes_arr(
420       new (std::nothrow) int64_t *[input_num](), std::bind(dyn_arr_deleter, std::placeholders::_1, input_num));
421     for (size_t i = 0; i < input_num; i++) {
422       auto in_shape = input_args[i]->BuildShape();
423       MS_EXCEPTION_IF_NULL(in_shape);
424       auto in_shape_ptr = in_shape->cast<ShapePtr>();
425       MS_EXCEPTION_IF_NULL(in_shape_ptr);
426       auto in_shape_vec = in_shape_ptr->shape();
427       auto in_shape_dim = in_shape_vec.size();
428       in_dims_arr[i] = in_shape_dim;
429       MS_ERROR_IF_TRUE_W_RET_N_LOG(in_shape_dim * sizeof(size_t) > maxMallocSize, {},
430                                    "The in_shape_dim is too large for memory allocation.");
431       (in_shapes_arr.get())[i] = new int64_t[in_shape_dim];
432       for (size_t j = 0; j < in_shape_dim; j++) {
433         (in_shapes_arr.get())[i][j] = in_shape_vec[j];
434       }
435     }
436     auto ret = info.shape_infer_func(in_shapes_arr.get(), in_dims_arr.get(), input_num, out_shapes_arr.get(),
437                                      out_dims_arr.get(), output_num);
438     if (ret != RET_OK) {
439       MS_LOG(ERROR) << "Failed to call the shape infer function of custom op!";
440       return {};
441     }
442     auto infer_shape = BuildShape(out_shapes_arr.get(), out_dims_arr.get(), output_num);
443     return infer_shape;
444   } else {
445     MS_LOG(ERROR) << "Either output shape or output shape infer function must be specified!";
446     return {};
447   }
448 }
449 
CustomOpInferType(const CustomOpInfo & info,const std::vector<AbstractBasePtr> & input_args)450 std::vector<TypePtr> CustomOpInferType(const CustomOpInfo &info, const std::vector<AbstractBasePtr> &input_args) {
451   if (info.output_dtypes != nullptr) {
452     auto infer_dtype = BuildType(info.output_dtypes, info.output_num);
453     return infer_dtype;
454   } else if (info.shape_infer_func != nullptr) {
455     size_t input_num = info.input_num;
456     size_t output_num = info.output_num;
457     auto in_dtypes_arr = std::make_unique<DataTypeC[]>(input_num);
458     auto out_dtypes_arr = std::make_unique<DataTypeC[]>(output_num);
459     for (size_t i = 0; i < input_num; i++) {
460       auto in_type = input_args[i]->BuildType();
461       MS_EXCEPTION_IF_NULL(in_type);
462       auto real_type = in_type;
463       if (in_type->isa<TensorTypeImpl>()) {
464         auto tensor_type = in_type->cast<TensorTypePtr>();
465         real_type = tensor_type->element();
466       }
467       auto in_type_id = (enum DataTypeC)(real_type->type_id());
468       in_dtypes_arr[i] = in_type_id;
469     }
470     STATUS ret = info.dtype_infer_func(in_dtypes_arr.get(), input_num, out_dtypes_arr.get(), output_num);
471     if (ret != RET_OK) {
472       MS_LOG(ERROR) << "Failed to call the dtype infer function of custom op!";
473       return {};
474     }
475     auto infer_dtype = BuildType(out_dtypes_arr.get(), output_num);
476     return infer_dtype;
477   } else {
478     MS_LOG(ERROR) << "Either output dtype or output dtype infer function must be specified!";
479     return {};
480   }
481 }
482 
MSNewCustomOp(ResMgrHandle res_mgr,GraphHandle graph,Handle const inputs[],size_t input_num,CustomOpInfo info)483 NodeHandle MSNewCustomOp(ResMgrHandle res_mgr, GraphHandle graph, Handle const inputs[], size_t input_num,
484                          CustomOpInfo info) {
485   if (res_mgr == nullptr || graph == nullptr) {
486     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
487     return nullptr;
488   }
489   MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num != info.input_num, nullptr,
490                                "Input node number is not matched with the input number specified in custom op info.");
491   auto ret = CheckCustomOpInfo(info);
492   MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Invalid custom op info.");
493   try {
494     auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
495     auto org_infer = res_mgr_ptr->GetInfer();
496     res_mgr_ptr->SetInfer(false);
497     NodeHandle custom_op =
498       MSNewOp(res_mgr, graph, "Custom", inputs, info.input_num, info.attr_names, info.attr_values, info.attr_num);
499     MS_ERROR_IF_TRUE_W_RET_N_LOG(custom_op == nullptr, nullptr, "Create Custom op failed!");
500     res_mgr_ptr->SetInfer(org_infer);
501     // Supplement necessary attributes
502     ret = MSOpSetAttrString(res_mgr, custom_op, mindspore::kAttrFuncType, info.func_type);
503     MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Custom op set func type attribute failed.");
504     ret = MSOpSetAttrString(res_mgr, custom_op, mindspore::kAttrFuncName, info.func_name);
505     MS_ERROR_IF_TRUE_W_RET_N_LOG(ret != RET_OK, nullptr, "Custom op set func name attribute failed.");
506     // Build json object
507     nlohmann::json json_obj = ConvertOpInfoToJson(info);
508     MS_ERROR_IF_TRUE_W_RET_N_LOG(json_obj.empty(), nullptr, "Failed to convert op info to json.");
509     // Create op info and set info map
510     auto op_name = json_obj.at(mindspore::kernel::kOpName).get<std::string>();
511     auto imply_type = json_obj.at(mindspore::kernel::kImplyType).get<std::string>();
512     std::string func_name = info.func_name;
513     std::string target_name = info.target;
514     auto iter = mindspore::kernel::kImplyTypeStrToEnumMap.find(imply_type);
515     if (iter == mindspore::kernel::kImplyTypeStrToEnumMap.end()) {
516       MS_LOG(ERROR) << "Not support imply_type: " << imply_type;
517       return nullptr;
518     }
519     auto op_info = mindspore::kernel::OpLib::DecodeOpInfo(json_obj, iter->second, "");
520     if (op_info == nullptr) {
521       MS_LOG(ERROR) << "Decode op info failed: func_name: " << func_name << " imply_type " << imply_type;
522       return nullptr;
523     }
524     op_info->set_processor(imply_type);
525     auto key = op_name + imply_type;
526     auto &op_infos = mindspore::kernel::OpLib::GetOpInfoMap();
527     (void)op_infos[iter->second].insert(std::pair<std::string, mindspore::kernel::OpInfoPtr>(key, op_info));
528     // Infer shape and type
529     mindspore::AbstractBasePtrList abs_list{};
530     for (size_t i = 0; i < input_num; ++i) {
531       auto in_node = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
532       MS_EXCEPTION_IF_NULL(in_node);
533       abs_list.push_back(in_node->abstract());
534     }
535     auto infer_shape = CustomOpInferShape(info, abs_list);
536     auto infer_type = CustomOpInferType(info, abs_list);
537     AbstractBasePtr custom_abs = BuildAbstract(infer_shape, infer_type);
538     MS_EXCEPTION_IF_NULL(custom_abs);
539     auto src_op = GetSrcPtr<CNodePtr>(res_mgr, custom_op);
540     MS_EXCEPTION_IF_NULL(src_op);
541     src_op->set_abstract(custom_abs);
542     return custom_op;
543   } catch (const std::exception &e) {
544     MS_LOG(ERROR) << "Get custom op failed. Error info: " << e.what();
545     return nullptr;
546   }
547 }
548 
MSOpGetInput(ResMgrHandle res_mgr,ConstNodeHandle op,size_t i)549 NodeHandle MSOpGetInput(ResMgrHandle res_mgr, ConstNodeHandle op, size_t i) {
550   if (res_mgr == nullptr || op == nullptr) {
551     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
552     return nullptr;
553   }
554   mindspore::AnfNodePtr anf_node = nullptr;
555   try {
556     auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
557     MS_EXCEPTION_IF_NULL(src_cnode);
558     if (i >= src_cnode->size() - 1) {
559       MS_LOG(ERROR) << "Invalid input index, it should be less than " << src_cnode->size() - 1 << ", but got: " << i;
560       return nullptr;
561     }
562     anf_node = src_cnode->input(i + 1);
563   } catch (const std::exception &e) {
564     MS_LOG(ERROR) << "Get input from CNode failed. Error info: " << e.what();
565     return nullptr;
566   }
567   return GetRawPtr(res_mgr, anf_node);
568 }
569 
MSOpGetInputsNum(ResMgrHandle res_mgr,ConstNodeHandle op,STATUS * error)570 size_t MSOpGetInputsNum(ResMgrHandle res_mgr, ConstNodeHandle op, STATUS *error) {
571   if (error == nullptr) {
572     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
573     return 0;
574   }
575   if (res_mgr == nullptr || op == nullptr) {
576     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
577     *error = RET_NULL_PTR;
578     return 0;
579   }
580   size_t input_num;
581   try {
582     auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
583     MS_EXCEPTION_IF_NULL(src_cnode);
584     input_num = src_cnode->size() - 1;
585   } catch (const std::exception &e) {
586     MS_LOG(ERROR) << "FuncGraph get input number failed. Error info: " << e.what();
587     *error = RET_ERROR;
588     return 0;
589   }
590   *error = RET_OK;
591   return input_num;
592 }
593 
MSOpGetInputs(ResMgrHandle res_mgr,ConstNodeHandle op,NodeHandle inputs[],size_t input_num)594 STATUS MSOpGetInputs(ResMgrHandle res_mgr, ConstNodeHandle op, NodeHandle inputs[], size_t input_num) {
595   if (res_mgr == nullptr || op == nullptr || inputs == nullptr) {
596     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [inputs] is nullptr.";
597     return RET_NULL_PTR;
598   }
599   try {
600     auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
601     MS_EXCEPTION_IF_NULL(src_cnode);
602     auto in_num = src_cnode->size() - 1;
603     if (in_num != input_num) {
604       MS_LOG(ERROR) << "Invalid input number, it should be: " << in_num << ", but got: " << input_num;
605       return RET_ERROR;
606     }
607     auto cnode_inputs = src_cnode->inputs();
608     for (size_t i = 0; i < input_num; i++) {
609       inputs[i] = GetRawPtr(res_mgr, cnode_inputs[i + 1]);
610     }
611   } catch (const std::exception &e) {
612     MS_LOG(ERROR) << "Get inputs from CNode failed. Error info: " << e.what();
613     return RET_ERROR;
614   }
615   return RET_OK;
616 }
617 
MSOpGetOutputDimension(ResMgrHandle res_mgr,ConstNodeHandle op,size_t output_index,STATUS * ret)618 size_t MSOpGetOutputDimension(ResMgrHandle res_mgr, ConstNodeHandle op, size_t output_index, STATUS *ret) {
619   if (res_mgr == nullptr || op == nullptr) {
620     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
621     *ret = RET_NULL_PTR;
622     return 0;
623   }
624   try {
625     auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
626     MS_EXCEPTION_IF_NULL(src_cnode);
627     std::vector<int64_t> shape = mindspore::common::AnfAlgo::GetOutputInferShape(src_cnode, output_index);
628     return shape.size();
629   } catch (const std::exception &e) {
630     MS_LOG(ERROR) << "Get Shape from OP/CNode failed. Error info: " << e.what();
631     *ret = RET_ERROR;
632     return 0;
633   }
634 }
635 
MSOpGetOutputShape(ResMgrHandle res_mgr,ConstNodeHandle op,int64_t shape_ret[],size_t dim,size_t output_index)636 STATUS MSOpGetOutputShape(ResMgrHandle res_mgr, ConstNodeHandle op, int64_t shape_ret[], size_t dim,
637                           size_t output_index) {
638   if (res_mgr == nullptr || op == nullptr) {
639     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
640     return RET_NULL_PTR;
641   }
642   try {
643     auto src_cnode = GetSrcPtr<CNodePtr>(res_mgr, op);
644     MS_EXCEPTION_IF_NULL(src_cnode);
645     std::vector<int64_t> shape = mindspore::common::AnfAlgo::GetOutputInferShape(src_cnode, output_index);
646     MS_EXCEPTION_IF_CHECK_FAIL(
647       dim >= shape.size(),
648       "Input dimension less than the actual Dimension. Please ensure shape_ret have enough space.");
649     (void)std::copy(shape.begin(), shape.end(), shape_ret);
650   } catch (const std::exception &e) {
651     MS_LOG(ERROR) << "Get Shape from OP/CNode failed. Error info: " << e.what();
652     return RET_ERROR;
653   }
654   return RET_OK;
655 }
656 
MSNewFuncCallNode(ResMgrHandle res_mgr,GraphHandle graph,ConstGraphHandle sub_graph,Handle const inputs[],size_t input_num)657 NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph, ConstGraphHandle sub_graph, Handle const inputs[],
658                              size_t input_num) {
659   if (res_mgr == nullptr || graph == nullptr || sub_graph == nullptr) {
660     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [sub_graph] is nullptr.";
661     return nullptr;
662   }
663   CNodePtr cnode = nullptr;
664   try {
665     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
666     MS_EXCEPTION_IF_NULL(res_fg);
667     auto res_sub_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, sub_graph);
668     MS_EXCEPTION_IF_NULL(res_sub_fg);
669     auto sub_node = mindspore::NewValueNode(res_sub_fg);
670     sub_node->set_abstract(res_sub_fg->ToAbstract());
671     std::vector<AnfNodePtr> cnode_inputs{sub_node};
672     for (size_t i = 0; i < input_num; ++i) {
673       auto cnode_input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
674       MS_EXCEPTION_IF_NULL(cnode_input);
675       cnode_inputs.push_back(cnode_input);
676     }
677     cnode = res_fg->NewCNodeInOrder(cnode_inputs);
678     MS_EXCEPTION_IF_NULL(res_sub_fg->output());
679     cnode->set_abstract(res_sub_fg->output()->abstract());
680     (void)res_fg->AddFuncGraphUsed(res_sub_fg);
681   } catch (const std::exception &e) {
682     MS_LOG(ERROR) << "FuncGraph create SubGraph node failed. Error info: " << e.what();
683     return nullptr;
684   }
685   MS_LOG(INFO) << "Add function call node";
686   return GetRawPtr(res_mgr, cnode);
687 }
688 
MSNewPlaceholder(ResMgrHandle res_mgr,GraphHandle graph,DataTypeC type,const int64_t shape[],size_t shape_size)689 NodeHandle MSNewPlaceholder(ResMgrHandle res_mgr, GraphHandle graph, DataTypeC type, const int64_t shape[],
690                             size_t shape_size) {
691   if (res_mgr == nullptr || graph == nullptr) {
692     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
693     return nullptr;
694   }
695   ParameterPtr param = nullptr;
696   try {
697     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
698     MS_EXCEPTION_IF_NULL(res_fg);
699     param = res_fg->add_parameter();
700     auto type_ptr = mindspore::TypeIdToType(mindspore::TypeId(type));
701     AbstractBasePtr abs = GetAbstract(type_ptr, shape, shape_size, true);
702     param->set_abstract(abs);
703   } catch (const std::exception &e) {
704     MS_LOG(ERROR) << "FuncGraph add parameter failed. Error info: " << e.what();
705     return nullptr;
706   }
707   return GetRawPtr(res_mgr, param);
708 }
709 
MSNewVariableScalarFloat32(ResMgrHandle res_mgr,GraphHandle graph,float value)710 NodeHandle MSNewVariableScalarFloat32(ResMgrHandle res_mgr, GraphHandle graph, float value) {
711   if (res_mgr == nullptr || graph == nullptr) {
712     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
713     return nullptr;
714   }
715   ParameterPtr param = nullptr;
716   try {
717     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
718     MS_EXCEPTION_IF_NULL(res_fg);
719     param = GetScalarParam<float>(res_fg, value, mindspore::kNumberTypeFloat32);
720     MS_EXCEPTION_IF_NULL(param);
721   } catch (const std::exception &e) {
722     MS_LOG(ERROR) << "New Scalar Variable failed. Error info: " << e.what();
723     return nullptr;
724   }
725   return GetRawPtr(res_mgr, param);
726 }
727 
MSNewVariableScalarInt32(ResMgrHandle res_mgr,GraphHandle graph,int value)728 NodeHandle MSNewVariableScalarInt32(ResMgrHandle res_mgr, GraphHandle graph, int value) {
729   if (res_mgr == nullptr || graph == nullptr) {
730     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
731     return nullptr;
732   }
733   ParameterPtr param = nullptr;
734   try {
735     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
736     MS_EXCEPTION_IF_NULL(res_fg);
737     param = GetScalarParam<float>(res_fg, value, mindspore::kNumberTypeInt32);
738     MS_EXCEPTION_IF_NULL(param);
739   } catch (const std::exception &e) {
740     MS_LOG(ERROR) << "New Scalar Variable failed. Error info: " << e.what();
741     return nullptr;
742   }
743   return GetRawPtr(res_mgr, param);
744 }
745 
MSNewVariableArray(ResMgrHandle res_mgr,GraphHandle graph,void * data,DataTypeC type,const int64_t shape[],size_t shape_size,size_t data_len)746 NodeHandle MSNewVariableArray(ResMgrHandle res_mgr, GraphHandle graph, void *data, DataTypeC type,
747                               const int64_t shape[], size_t shape_size, size_t data_len) {
748   if (res_mgr == nullptr || graph == nullptr || data == nullptr || shape == nullptr) {
749     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [data] or [shape] is nullptr.";
750     return nullptr;
751   }
752   ParameterPtr param = nullptr;
753   ShapeVector shape_vec(shape, shape + shape_size);
754   try {
755     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
756     MS_EXCEPTION_IF_NULL(res_fg);
757     param = res_fg->add_parameter();
758     auto tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data, data_len);
759     tensor->set_param_info(std::make_shared<mindspore::ParamInfo>());
760     param->set_abstract(tensor->ToAbstract());
761     param->set_default_param(tensor);
762   } catch (const std::exception &e) {
763     MS_LOG(ERROR) << "New Tensor Variable failed. Error info: " << e.what();
764     return nullptr;
765   }
766   return GetRawPtr(res_mgr, param);
767 }
768 
MSNewVariableFromTensor(ResMgrHandle res_mgr,GraphHandle graph,ConstTensorHandle tensor)769 NodeHandle MSNewVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, ConstTensorHandle tensor) {
770   if (res_mgr == nullptr || graph == nullptr || tensor == nullptr) {
771     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [tensor] is nullptr.";
772     return nullptr;
773   }
774   ParameterPtr param = nullptr;
775   try {
776     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
777     MS_EXCEPTION_IF_NULL(res_fg);
778     auto tensor_impl = GetSrcPtr<TensorPtr>(res_mgr, tensor);
779     MS_EXCEPTION_IF_NULL(tensor_impl);
780     param = res_fg->add_parameter();
781     param->set_abstract(tensor_impl->ToAbstract());
782     param->set_default_param(tensor_impl);
783   } catch (const std::exception &e) {
784     MS_LOG(ERROR) << "New Tensor Variable failed. Error info: " << e.what();
785     return nullptr;
786   }
787   return GetRawPtr(res_mgr, param);
788 }
789 
MSVariableArrayGetDataSize(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)790 size_t MSVariableArrayGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
791   if (error == nullptr) {
792     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
793     return 0;
794   }
795   if (res_mgr == nullptr || node == nullptr) {
796     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
797     *error = RET_NULL_PTR;
798     return 0;
799   }
800   try {
801     auto node_impl = GetSrcPtr<ParameterPtr>(res_mgr, node);
802     MS_EXCEPTION_IF_NULL(node_impl);
803     auto val = node_impl->default_param();
804     MS_EXCEPTION_IF_NULL(val);
805     auto tensor = val->cast<TensorPtr>();
806     MS_EXCEPTION_IF_NULL(tensor);
807     size_t data_size = tensor->Size();
808     *error = RET_OK;
809     return data_size;
810   } catch (const std::exception &e) {
811     MS_LOG(ERROR) << "Tensor Variable get data failed. Error info: " << e.what();
812     *error = RET_ERROR;
813     return 0;
814   }
815 }
816 
MSVariableArrayGetData(ResMgrHandle res_mgr,ConstNodeHandle node)817 void *MSVariableArrayGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
818   if (res_mgr == nullptr || node == nullptr) {
819     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
820     return nullptr;
821   }
822   try {
823     auto node_impl = GetSrcPtr<ParameterPtr>(res_mgr, node);
824     MS_EXCEPTION_IF_NULL(node_impl);
825     auto val = node_impl->default_param();
826     MS_EXCEPTION_IF_NULL(val);
827     auto tensor = val->cast<TensorPtr>();
828     MS_EXCEPTION_IF_NULL(tensor);
829     void *data = tensor->data_c();
830     return data;
831   } catch (const std::exception &e) {
832     MS_LOG(ERROR) << "Tensor Variable get data failed. Error info: " << e.what();
833     return nullptr;
834   }
835 }
836 
MSNewConstantArray(ResMgrHandle res_mgr,void * data,DataTypeC type,const int64_t shape[],size_t shape_size,size_t data_len)837 NodeHandle MSNewConstantArray(ResMgrHandle res_mgr, void *data, DataTypeC type, const int64_t shape[],
838                               size_t shape_size, size_t data_len) {
839   if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
840     MS_LOG(ERROR) << "Input Handle [res_mgr] or [data] or [shape] is nullptr.";
841     return nullptr;
842   }
843   ShapeVector shape_vec(shape, shape + shape_size);
844   ValueNodePtr value_node = nullptr;
845   try {
846     auto tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data, data_len);
847     tensor->set_param_info(std::make_shared<mindspore::ParamInfo>());
848     value_node = mindspore::NewValueNode(tensor);
849     value_node->set_abstract(tensor->ToAbstract());
850   } catch (const std::exception &e) {
851     MS_LOG(ERROR) << "New Tensor Variable failed. Error info: " << e.what();
852     return nullptr;
853   }
854   return GetRawPtr(res_mgr, value_node);
855 }
856 
MSNewConstantFromTensor(ResMgrHandle res_mgr,TensorHandle tensor)857 NodeHandle MSNewConstantFromTensor(ResMgrHandle res_mgr, TensorHandle tensor) {
858   if (res_mgr == nullptr || tensor == nullptr) {
859     MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
860     return nullptr;
861   }
862   ValueNodePtr value_node = nullptr;
863   try {
864     auto tensor_impl = GetSrcPtr<TensorPtr>(res_mgr, tensor);
865     MS_EXCEPTION_IF_NULL(tensor_impl);
866     value_node = mindspore::NewValueNode(tensor_impl);
867     value_node->set_abstract(tensor_impl->ToAbstract());
868   } catch (const std::exception &e) {
869     MS_LOG(ERROR) << "New Tensor Variable failed. Error info: " << e.what();
870     return nullptr;
871   }
872   return GetRawPtr(res_mgr, value_node);
873 }
874 
MSConstantArrayGetDataSize(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)875 size_t MSConstantArrayGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
876   if (error == nullptr) {
877     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
878     return 0;
879   }
880   if (res_mgr == nullptr || node == nullptr) {
881     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
882     *error = RET_NULL_PTR;
883     return 0;
884   }
885   try {
886     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
887     MS_EXCEPTION_IF_NULL(node_impl);
888     auto val = node_impl->value();
889     MS_EXCEPTION_IF_NULL(val);
890     auto tensor = val->cast<TensorPtr>();
891     MS_EXCEPTION_IF_NULL(tensor);
892     size_t data_size = tensor->Size();
893     *error = RET_OK;
894     return data_size;
895   } catch (const std::exception &e) {
896     MS_LOG(ERROR) << "Tensor Constant get data failed. Error info: " << e.what();
897     *error = RET_ERROR;
898     return 0;
899   }
900 }
901 
MSConstantArrayGetData(ResMgrHandle res_mgr,ConstNodeHandle node)902 void *MSConstantArrayGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
903   if (res_mgr == nullptr || node == nullptr) {
904     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
905     return nullptr;
906   }
907   try {
908     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
909     MS_EXCEPTION_IF_NULL(node_impl);
910     auto val = node_impl->value();
911     MS_EXCEPTION_IF_NULL(val);
912     auto tensor = val->cast<TensorPtr>();
913     MS_EXCEPTION_IF_NULL(tensor);
914     void *data = tensor->data_c();
915     return data;
916   } catch (const std::exception &e) {
917     MS_LOG(ERROR) << "Tensor Constant get data failed. Error info: " << e.what();
918     return nullptr;
919   }
920 }
921 
MSNewConstantScalarFloat32(ResMgrHandle res_mgr,float value)922 NodeHandle MSNewConstantScalarFloat32(ResMgrHandle res_mgr, float value) {
923   MS_LOG(INFO) << "New Float32 Scalar Value!s";
924   if (res_mgr == nullptr) {
925     MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
926     return nullptr;
927   }
928   auto value_node = mindspore::NewValueNode(value);
929   value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
930   return GetRawPtr(res_mgr, value_node);
931 }
932 
MSNewConstantScalarBool(ResMgrHandle res_mgr,bool value)933 NodeHandle MSNewConstantScalarBool(ResMgrHandle res_mgr, bool value) {
934   MS_LOG(INFO) << "New Bool Scalar Value!";
935   if (res_mgr == nullptr) {
936     MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
937     return nullptr;
938   }
939   auto value_node = mindspore::NewValueNode(value);
940   value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
941   return GetRawPtr(res_mgr, value_node);
942 }
943 
MSNewConstantScalarInt32(ResMgrHandle res_mgr,int value)944 NodeHandle MSNewConstantScalarInt32(ResMgrHandle res_mgr, int value) {
945   MS_LOG(INFO) << "New Int32 Scalar Value!";
946   if (res_mgr == nullptr) {
947     MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
948     return nullptr;
949   }
950   auto value_node = mindspore::NewValueNode(value);
951   value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
952   return GetRawPtr(res_mgr, value_node);
953 }
954 
MSNewConstantScalarInt64(ResMgrHandle res_mgr,int64_t value)955 NodeHandle MSNewConstantScalarInt64(ResMgrHandle res_mgr, int64_t value) {
956   MS_LOG(INFO) << "New Int64 Scalar Value!";
957   if (res_mgr == nullptr) {
958     MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
959     return nullptr;
960   }
961   auto value_node = mindspore::NewValueNode(value);
962   value_node->set_abstract(std::make_shared<AbstractScalarImpl>(value));
963   return GetRawPtr(res_mgr, value_node);
964 }
965 
MSNewConstantString(ResMgrHandle res_mgr,const char * str)966 NodeHandle MSNewConstantString(ResMgrHandle res_mgr, const char *str) {
967   MS_LOG(INFO) << "New String Scalar Value!";
968   if (res_mgr == nullptr || str == nullptr) {
969     MS_LOG(ERROR) << "Input Handle [res_mgr] or [str] is nullptr.";
970     return nullptr;
971   }
972   string str_val(str);
973   auto value_node = mindspore::NewValueNode(str_val);
974   value_node->set_abstract(std::make_shared<AbstractScalarImpl>(str_val));
975   return GetRawPtr(res_mgr, value_node);
976 }
977 
MSNewConstantTupleInt64(ResMgrHandle res_mgr,const int64_t vec[],size_t size)978 NodeHandle MSNewConstantTupleInt64(ResMgrHandle res_mgr, const int64_t vec[], size_t size) {
979   MS_LOG(INFO) << "New Vector Value!";
980   if (res_mgr == nullptr || vec == nullptr) {
981     MS_LOG(ERROR) << "Input Handle [res_mgr] or [vec] is nullptr.";
982     return nullptr;
983   }
984   auto value_node = mindspore::NewValueNode(std::vector<int64_t>(vec, vec + size));
985   mindspore::AbstractBasePtrList abs_list = {};
986   for (size_t i = 0; i < size; i++) {
987     AbstractBasePtr base = std::make_shared<AbstractScalarImpl>(vec[i]);
988     abs_list.push_back(base);
989   }
990   auto abstract = std::make_shared<AbstractTupleImpl>(abs_list);
991   value_node->set_abstract(abstract);
992   return GetRawPtr(res_mgr, value_node);
993 }
994 
MSNewConstantType(ResMgrHandle res_mgr,DataTypeC type)995 NodeHandle MSNewConstantType(ResMgrHandle res_mgr, DataTypeC type) {
996   MS_LOG(INFO) << "New Type Value: " << type;
997   if (res_mgr == nullptr) {
998     MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
999     return nullptr;
1000   }
1001   auto type_ptr = mindspore::TypeIdToType(mindspore::TypeId(type));
1002   auto value_node = mindspore::NewValueNode(type_ptr);
1003   auto abstract = std::make_shared<AbstractTypeImpl>(type_ptr);
1004   value_node->set_abstract(abstract);
1005   return GetRawPtr(res_mgr, value_node);
1006 }
1007 
MSConstantScalarGetValueInt32(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)1008 int MSConstantScalarGetValueInt32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
1009   MS_LOG(INFO) << "Get Int32 Scalar Value!";
1010   if (error == nullptr) {
1011     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
1012     return 0;
1013   }
1014   if (res_mgr == nullptr || node == nullptr) {
1015     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
1016     *error = RET_NULL_PTR;
1017     return 0;
1018   }
1019   int ret_val = 0;
1020   *error = RET_OK;
1021   try {
1022     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1023     MS_EXCEPTION_IF_NULL(node_impl);
1024     auto val = node_impl->value();
1025     MS_EXCEPTION_IF_NULL(val);
1026     if (val->isa<TensorImpl>()) {
1027       auto val_tensor = val->cast<TensorPtr>();
1028       auto data = val_tensor->data_c();
1029       MS_EXCEPTION_IF_NULL(data);
1030       ret_val = static_cast<int *>(data)[0];
1031     } else if (val->isa<Int32ImmImpl>()) {
1032       auto val_imm = val->cast<Int32ImmPtr>();
1033       ret_val = val_imm->value();
1034     } else {
1035       MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
1036       *error = RET_ERROR;
1037     }
1038   } catch (const std::exception &e) {
1039     MS_LOG(ERROR) << "Get Int32 Scalar value failed. Error info: " << e.what();
1040     *error = RET_ERROR;
1041   }
1042   return ret_val;
1043 }
1044 
MSConstantScalarGetValueFloat32(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)1045 float MSConstantScalarGetValueFloat32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
1046   MS_LOG(INFO) << "Get Float32 Scalar Value!";
1047   if (error == nullptr) {
1048     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
1049     return 0;
1050   }
1051   if (res_mgr == nullptr || node == nullptr) {
1052     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
1053     *error = RET_NULL_PTR;
1054     return 0;
1055   }
1056   float ret_val = 0;
1057   *error = RET_OK;
1058   try {
1059     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1060     MS_EXCEPTION_IF_NULL(node_impl);
1061     auto val = node_impl->value();
1062     MS_EXCEPTION_IF_NULL(val);
1063     if (val->isa<TensorImpl>()) {
1064       auto val_tensor = val->cast<TensorPtr>();
1065       auto data = val_tensor->data_c();
1066       MS_EXCEPTION_IF_NULL(data);
1067       ret_val = static_cast<float *>(data)[0];
1068     } else if (val->isa<Float32ImmImpl>()) {
1069       auto val_imm = val->cast<Float32ImmPtr>();
1070       ret_val = val_imm->value();
1071     } else {
1072       MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
1073       *error = RET_ERROR;
1074     }
1075   } catch (const std::exception &e) {
1076     MS_LOG(ERROR) << "Get Float32 Scalar value failed. Error info: " << e.what();
1077     *error = RET_ERROR;
1078   }
1079   return ret_val;
1080 }
1081 
MSConstantScalarGetValueBool(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)1082 bool MSConstantScalarGetValueBool(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
1083   MS_LOG(INFO) << "Get Bool Scalar Value!";
1084   if (error == nullptr) {
1085     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
1086     return false;
1087   }
1088   if (res_mgr == nullptr || node == nullptr) {
1089     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
1090     *error = RET_NULL_PTR;
1091     return false;
1092   }
1093   int ret_val = false;
1094   *error = RET_OK;
1095   try {
1096     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1097     MS_EXCEPTION_IF_NULL(node_impl);
1098     auto val = node_impl->value();
1099     MS_EXCEPTION_IF_NULL(val);
1100     if (val->isa<TensorImpl>()) {
1101       auto val_tensor = val->cast<TensorPtr>();
1102       auto data = val_tensor->data_c();
1103       MS_EXCEPTION_IF_NULL(data);
1104       ret_val = static_cast<bool *>(data)[0];
1105     } else if (val->isa<BoolImmImpl>()) {
1106       auto val_imm = val->cast<BoolImmPtr>();
1107       ret_val = val_imm->value();
1108     } else {
1109       MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
1110       *error = RET_ERROR;
1111     }
1112   } catch (const std::exception &e) {
1113     MS_LOG(ERROR) << "Get Bool Scalar value failed. Error info: " << e.what();
1114     *error = RET_ERROR;
1115   }
1116   return ret_val;
1117 }
1118 
MSConstantScalarGetValueInt64(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)1119 int64_t MSConstantScalarGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
1120   MS_LOG(INFO) << "Get Int64 Scalar Value!";
1121   if (error == nullptr) {
1122     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
1123     return 0;
1124   }
1125   if (res_mgr == nullptr || node == nullptr) {
1126     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
1127     *error = RET_NULL_PTR;
1128     return 0;
1129   }
1130   int64_t ret_val = 0;
1131   *error = RET_OK;
1132   try {
1133     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1134     MS_EXCEPTION_IF_NULL(node_impl);
1135     auto val = node_impl->value();
1136     MS_EXCEPTION_IF_NULL(val);
1137     if (val->isa<TensorImpl>()) {
1138       auto val_tensor = val->cast<TensorPtr>();
1139       auto data = val_tensor->data_c();
1140       MS_EXCEPTION_IF_NULL(data);
1141       ret_val = static_cast<int64_t *>(data)[0];
1142     } else if (val->isa<Int64ImmImpl>()) {
1143       auto val_imm = val->cast<Int64ImmPtr>();
1144       ret_val = val_imm->value();
1145     } else {
1146       MS_LOG(ERROR) << "Input node has invalid value type: " << val->type_name();
1147       *error = RET_ERROR;
1148     }
1149   } catch (const std::exception &e) {
1150     MS_LOG(ERROR) << "Get Int64 Scalar value failed. Error info: " << e.what();
1151     *error = RET_ERROR;
1152   }
1153   return ret_val;
1154 }
1155 
MSConstantStringGetValue(ResMgrHandle res_mgr,ConstNodeHandle node,char str_buf[],size_t str_len)1156 STATUS MSConstantStringGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
1157   MS_LOG(INFO) << "Get String Constant Value!";
1158   if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
1159     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
1160     return RET_NULL_PTR;
1161   }
1162   try {
1163     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1164     MS_EXCEPTION_IF_NULL(node_impl);
1165     auto val = node_impl->value();
1166     MS_EXCEPTION_IF_NULL(val);
1167     auto val_str = val->cast<StringImmPtr>();
1168     std::string ret_val = val_str->value();
1169     size_t valid_size = ret_val.size() < str_len - 1 ? ret_val.size() : str_len - 1;
1170     for (size_t i = 0; i < valid_size; i++) {
1171       str_buf[i] = ret_val.c_str()[i];
1172     }
1173     str_buf[valid_size] = '\0';
1174     return RET_OK;
1175   } catch (const std::exception &e) {
1176     MS_LOG(ERROR) << "Get String Constant value failed. Error info: " << e.what();
1177     return RET_ERROR;
1178   }
1179 }
1180 
MSConstantTupleGetSize(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)1181 size_t MSConstantTupleGetSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
1182   MS_LOG(INFO) << "Get Tuple Constant size!";
1183   if (error == nullptr) {
1184     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
1185     return 0;
1186   }
1187   if (res_mgr == nullptr || node == nullptr) {
1188     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
1189     *error = RET_NULL_PTR;
1190     return 0;
1191   }
1192   try {
1193     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1194     MS_EXCEPTION_IF_NULL(node_impl);
1195     auto val = node_impl->value();
1196     MS_EXCEPTION_IF_NULL(val);
1197     auto val_tuple = val->cast<ValueTuplePtr>();
1198     auto tuple_size = val_tuple->size();
1199     *error = RET_OK;
1200     return tuple_size;
1201   } catch (const std::exception &e) {
1202     MS_LOG(ERROR) << "Get Tuple Constant size failed. Error info: " << e.what();
1203     *error = RET_ERROR;
1204     return 0;
1205   }
1206 }
1207 
MSConstantTupleGetValueInt64(ResMgrHandle res_mgr,ConstNodeHandle node,int64_t vec[],size_t size)1208 STATUS MSConstantTupleGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, int64_t vec[], size_t size) {
1209   MS_LOG(INFO) << "Get Tuple Constant Value!";
1210   if (res_mgr == nullptr || node == nullptr || vec == nullptr) {
1211     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [vec] is nullptr.";
1212     return RET_NULL_PTR;
1213   }
1214   try {
1215     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1216     MS_EXCEPTION_IF_NULL(node_impl);
1217     auto val = node_impl->value();
1218     MS_EXCEPTION_IF_NULL(val);
1219     auto val_tuple = val->cast<ValueTuplePtr>();
1220     auto val_list = val_tuple->value();
1221     if (val_list.size() != size) {
1222       MS_LOG(ERROR) << "Invalid input vector length, it should be: " << val_list.size() << ", but got: " << size;
1223       return RET_ERROR;
1224     }
1225     for (size_t i = 0; i < size; i++) {
1226       auto val_imm = val_list[i]->cast<Int64ImmPtr>();
1227       vec[i] = val_imm->value();
1228     }
1229     return RET_OK;
1230   } catch (const std::exception &e) {
1231     MS_LOG(ERROR) << "Get String Constant value failed. Error info: " << e.what();
1232     return RET_ERROR;
1233   }
1234 }
1235 
MSConstantTypeGetValue(ResMgrHandle res_mgr,ConstNodeHandle node,STATUS * error)1236 DataTypeC MSConstantTypeGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
1237   MS_LOG(INFO) << "Get Type Constant Value!";
1238   if (error == nullptr) {
1239     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
1240     return MS_INVALID_TYPE;
1241   }
1242   if (res_mgr == nullptr || node == nullptr) {
1243     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
1244     *error = RET_NULL_PTR;
1245     return MS_INVALID_TYPE;
1246   }
1247   try {
1248     auto node_impl = GetSrcPtr<ValueNodePtr>(res_mgr, node);
1249     MS_EXCEPTION_IF_NULL(node_impl);
1250     auto val = node_impl->value();
1251     MS_EXCEPTION_IF_NULL(val);
1252     auto val_type = val->cast<TypePtr>();
1253     auto ret_val = static_cast<DataTypeC>(val_type->type_id());
1254     *error = RET_OK;
1255     return ret_val;
1256   } catch (const std::exception &e) {
1257     MS_LOG(ERROR) << "Get Type Constant value failed. Error info: " << e.what();
1258     *error = RET_ERROR;
1259     return MS_INVALID_TYPE;
1260   }
1261 }
1262 
GetOpPrim(ResMgrHandle res_mgr,ConstNodeHandle node)1263 PrimitivePtr GetOpPrim(ResMgrHandle res_mgr, ConstNodeHandle node) {
1264   auto src_node = GetSrcPtr<CNodePtr>(res_mgr, node);
1265   auto node_input = src_node->input(0);
1266   if (node_input == nullptr) {
1267     MS_LOG(ERROR) << "The node's input is nullptr.";
1268     return nullptr;
1269   }
1270   auto prim_node = node_input->cast<ValueNodePtr>();
1271   if (prim_node == nullptr) {
1272     MS_LOG(ERROR) << "The node's input is with invalid type.";
1273     return nullptr;
1274   }
1275   auto node_value = prim_node->value();
1276   if (node_value == nullptr) {
1277     MS_LOG(ERROR) << "The node's value is nullptr.";
1278     return nullptr;
1279   }
1280   auto prim = node_value->cast<PrimitivePtr>();
1281   if (prim == nullptr) {
1282     MS_LOG(ERROR) << "The node's value is with invalid type.";
1283     return nullptr;
1284   }
1285   return prim;
1286 }
1287 
MSOpSetAttrScalarFloat32(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,float value)1288 STATUS MSOpSetAttrScalarFloat32(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, float value) {
1289   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1290     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1291     return RET_NULL_PTR;
1292   }
1293   auto prim = GetOpPrim(res_mgr, op);
1294   if (prim == nullptr) {
1295     MS_LOG(ERROR) << "Get primitive node failed";
1296     return RET_NULL_PTR;
1297   }
1298   prim->set_attr(attr_name, mindspore::MakeValue(value));
1299   return RET_OK;
1300 }
1301 
MSOpSetAttrScalarBool(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,bool value)1302 STATUS MSOpSetAttrScalarBool(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, bool value) {
1303   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1304     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1305     return RET_NULL_PTR;
1306   }
1307   auto prim = GetOpPrim(res_mgr, op);
1308   if (prim == nullptr) {
1309     MS_LOG(ERROR) << "Get primitive node failed";
1310     return RET_NULL_PTR;
1311   }
1312   prim->set_attr(attr_name, mindspore::MakeValue(value));
1313   return RET_OK;
1314 }
1315 
MSOpSetAttrScalarInt32(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,int32_t value)1316 STATUS MSOpSetAttrScalarInt32(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int32_t value) {
1317   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1318     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1319     return RET_NULL_PTR;
1320   }
1321   auto prim = GetOpPrim(res_mgr, op);
1322   if (prim == nullptr) {
1323     MS_LOG(ERROR) << "Get primitive node failed";
1324     return RET_NULL_PTR;
1325   }
1326   prim->set_attr(attr_name, mindspore::MakeValue(value));
1327   return RET_OK;
1328 }
1329 
MSOpSetAttrScalarInt64(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,int64_t value)1330 STATUS MSOpSetAttrScalarInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t value) {
1331   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1332     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1333     return RET_NULL_PTR;
1334   }
1335   auto prim = GetOpPrim(res_mgr, op);
1336   if (prim == nullptr) {
1337     MS_LOG(ERROR) << "Get primitive node failed";
1338     return RET_NULL_PTR;
1339   }
1340   prim->set_attr(attr_name, mindspore::MakeValue(value));
1341   return RET_OK;
1342 }
1343 
MSOpSetAttrType(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,DataTypeC value)1344 STATUS MSOpSetAttrType(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value) {
1345   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1346     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1347     return RET_NULL_PTR;
1348   }
1349   auto prim = GetOpPrim(res_mgr, op);
1350   if (prim == nullptr) {
1351     MS_LOG(ERROR) << "Get primitive node failed";
1352     return RET_NULL_PTR;
1353   }
1354   auto cxx_type = mindspore::TypeId(value);
1355   prim->set_attr(attr_name, mindspore::TypeIdToType(cxx_type));
1356   return RET_OK;
1357 }
1358 
MSOpSetAttrTypeArray(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,DataTypeC value[],size_t vec_size)1359 STATUS MSOpSetAttrTypeArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, DataTypeC value[],
1360                             size_t vec_size) {
1361   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1362     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1363     return RET_NULL_PTR;
1364   }
1365   auto prim = GetOpPrim(res_mgr, op);
1366   if (prim == nullptr) {
1367     MS_LOG(ERROR) << "Get primitive node failed";
1368     return RET_NULL_PTR;
1369   }
1370   std::vector<mindspore::ValuePtr> vec_value;
1371   mindspore::TypeId cxx_type;
1372   for (size_t i = 0; i < vec_size; i++) {
1373     cxx_type = mindspore::TypeId(value[i]);
1374     vec_value.push_back(mindspore::TypeIdToType(cxx_type));
1375   }
1376   prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
1377   return RET_OK;
1378 }
1379 
MSOpSetAttrArray(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,void * value,size_t vec_size,DataTypeC data_type)1380 STATUS MSOpSetAttrArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, void *value, size_t vec_size,
1381                         DataTypeC data_type) {
1382   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr || value == nullptr) {
1383     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] or [value_vec] is nullptr.";
1384     return RET_NULL_PTR;
1385   }
1386   auto prim = GetOpPrim(res_mgr, op);
1387   if (prim == nullptr) {
1388     MS_LOG(ERROR) << "Get primitive node failed";
1389     return RET_NULL_PTR;
1390   }
1391 
1392   switch (data_type) {
1393     case MS_BOOL: {
1394       std::vector<bool> vec_value(static_cast<bool *>(value), static_cast<bool *>(value) + vec_size);
1395       prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
1396       break;
1397     }
1398     case MS_INT32: {
1399       std::vector<int32_t> vec_value(static_cast<int32_t *>(value), static_cast<int32_t *>(value) + vec_size);
1400       prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
1401       break;
1402     }
1403     case MS_INT64: {
1404       std::vector<int64_t> vec_value(static_cast<int64_t *>(value), static_cast<int64_t *>(value) + vec_size);
1405       prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
1406       break;
1407     }
1408     case MS_FLOAT32: {
1409       std::vector<float> vec_value(static_cast<float *>(value), static_cast<float *>(value) + vec_size);
1410       prim->set_attr(attr_name, mindspore::MakeValue(vec_value));
1411       break;
1412     }
1413     default:
1414       MS_LOG(ERROR) << "Unrecognized datatype w/ DataTypeC ID: " << data_type << " , Attribute name: " << attr_name
1415                     << std::endl;
1416       return RET_ERROR;
1417   }
1418   return RET_OK;
1419 }
1420 
MSOpSetAttrStringArray(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,const char * value[],size_t vec_size)1421 STATUS MSOpSetAttrStringArray(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, const char *value[],
1422                               size_t vec_size) {
1423   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr || value == nullptr) {
1424     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] or [value_vec] is nullptr.";
1425     return RET_NULL_PTR;
1426   }
1427   auto prim = GetOpPrim(res_mgr, op);
1428   if (prim == nullptr) {
1429     MS_LOG(ERROR) << "Get primitive node failed";
1430     return RET_NULL_PTR;
1431   }
1432 
1433   std::vector<mindspore::ValuePtr> vec_value;
1434   for (size_t i = 0; i < vec_size; i++) {
1435     vec_value.push_back(mindspore::MakeValue(value[i]));
1436   }
1437   prim->set_attr(attr_name, std::make_shared<mindspore::ValueList>(vec_value));
1438   return RET_OK;
1439 }
1440 
MSOpSetAttrString(ResMgrHandle res_mgr,NodeHandle op,const char * attr_name,const char * value)1441 STATUS MSOpSetAttrString(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, const char *value) {
1442   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr || value == nullptr) {
1443     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] or [value_vec] is nullptr.";
1444     return RET_NULL_PTR;
1445   }
1446   auto prim = GetOpPrim(res_mgr, op);
1447   if (prim == nullptr) {
1448     MS_LOG(ERROR) << "Get primitive node failed";
1449     return RET_NULL_PTR;
1450   }
1451   std::string value_str(value);
1452   prim->set_attr(attr_name, mindspore::MakeValue(value_str));
1453   return RET_OK;
1454 }
1455 
MSOpGetAttrScalarInt64(ResMgrHandle res_mgr,ConstNodeHandle op,const char * attr_name,STATUS * error)1456 int64_t MSOpGetAttrScalarInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, STATUS *error) {
1457   if (error == nullptr) {
1458     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
1459     return 0;
1460   }
1461   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1462     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1463     *error = RET_NULL_PTR;
1464     return 0;
1465   }
1466   std::string attr_name_str(attr_name);
1467   try {
1468     auto prim = GetOpPrim(res_mgr, op);
1469     MS_EXCEPTION_IF_NULL(prim);
1470     auto value = prim->GetAttr(attr_name_str);
1471     auto value_int64 = value->cast<Int64ImmPtr>();
1472     MS_EXCEPTION_IF_NULL(value_int64);
1473     auto ret_val = value_int64->value();
1474     *error = RET_OK;
1475     return ret_val;
1476   } catch (const std::exception &e) {
1477     MS_LOG(ERROR) << " Get Attribute failed. Error info: " << e.what();
1478     *error = RET_ERROR;
1479     return 0;
1480   }
1481 }
1482 
MSOpGetAttrArrayInt64(ResMgrHandle res_mgr,ConstNodeHandle op,const char * attr_name,int64_t values[],size_t value_num)1483 STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, int64_t values[],
1484                              size_t value_num) {
1485   if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
1486     MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
1487     return RET_NULL_PTR;
1488   }
1489   std::string attr_name_str(attr_name);
1490   try {
1491     auto prim = GetOpPrim(res_mgr, op);
1492     MS_EXCEPTION_IF_NULL(prim);
1493     auto value = prim->GetAttr(attr_name_str);
1494     MS_EXCEPTION_IF_NULL(value);
1495     auto value_tuple = value->cast<ValueTuplePtr>();
1496     MS_EXCEPTION_IF_NULL(value_tuple);
1497     auto value_list = value_tuple->value();
1498     if (value_list.size() != value_num) {
1499       MS_LOG(ERROR) << "Invalid input vector length, it should be: " << value_list.size() << ", but got: " << value_num;
1500       return RET_ERROR;
1501     }
1502     for (size_t i = 0; i < value_num; i++) {
1503       auto val_imm = value_list[i]->cast<Int64ImmPtr>();
1504       values[i] = val_imm->value();
1505     }
1506     return RET_OK;
1507   } catch (const std::exception &e) {
1508     MS_LOG(ERROR) << "Get Attribute failed. Error info: " << e.what();
1509     return RET_ERROR;
1510   }
1511 }
1512 
MSOpSetName(ResMgrHandle res_mgr,NodeHandle node,const char * name)1513 STATUS MSOpSetName(ResMgrHandle res_mgr, NodeHandle node, const char *name) {
1514   if (res_mgr == nullptr || node == nullptr || name == nullptr) {
1515     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [name] is nullptr.";
1516     return RET_NULL_PTR;
1517   }
1518   auto node_impl = GetSrcPtr<CNodePtr>(res_mgr, node);
1519   if (node_impl == nullptr) {
1520     MS_LOG(ERROR) << "Get source pointer failed. Please check whether the input node is an operator node.";
1521     return RET_ERROR;
1522   }
1523   node_impl->set_fullname_with_scope(name);
1524   return RET_OK;
1525 }
1526 
MSNodeGetName(ResMgrHandle res_mgr,ConstNodeHandle node,char str_buf[],size_t str_len)1527 STATUS MSNodeGetName(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
1528   if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
1529     MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
1530     return RET_NULL_PTR;
1531   }
1532   auto node_impl = GetSrcPtr<AnfNodePtr>(res_mgr, node);
1533   if (node_impl == nullptr) {
1534     MS_LOG(ERROR) << "Get source pointer failed.";
1535     return RET_ERROR;
1536   }
1537   auto name = node_impl->fullname_with_scope();
1538   size_t valid_size = name.size() < str_len - 1 ? name.size() : str_len - 1;
1539   for (size_t i = 0; i < valid_size; i++) {
1540     str_buf[i] = name.c_str()[i];
1541   }
1542   str_buf[valid_size] = '\0';
1543   return RET_OK;
1544 }
1545 
1546 // dynamic op / eager mode
GenerateInnerInfo(ResMgrHandle res_mgr,const char * op_type,TensorHandle const inputs[],size_t input_num,size_t output_num,const DynamicOpInfo & extra_info)1547 std::shared_ptr<InnerOpInfo> GenerateInnerInfo(ResMgrHandle res_mgr, const char *op_type, TensorHandle const inputs[],
1548                                                size_t input_num, size_t output_num, const DynamicOpInfo &extra_info) {
1549   MS_EXCEPTION_IF_NULL(op_type);
1550   MS_EXCEPTION_IF_NULL(inputs);
1551   std::vector<ValuePtr> src_inputs{};
1552   std::vector<ShapeVector> out_shapes{};
1553   std::vector<DataTypeC> out_dtypes{};
1554   std::vector<std::pair<std::string, ValuePtr>> attrs_pair{};
1555   for (size_t i = 0; i < input_num; i++) {
1556     auto input = GetSrcPtr<ValuePtr>(res_mgr, inputs[i]);
1557     if (input == nullptr) {
1558       MS_LOG(EXCEPTION) << "Invalid input. Index: " << i;
1559     }
1560     (void)src_inputs.emplace_back(input);
1561   }
1562   if (extra_info.output_shapes != nullptr && extra_info.output_dtypes != nullptr) {
1563     for (size_t i = 0; i < output_num; i++) {
1564       MS_EXCEPTION_IF_NULL(extra_info.output_dims);
1565       size_t dim = extra_info.output_dims[i];
1566       ShapeVector out_shape{};
1567       MS_EXCEPTION_IF_NULL(extra_info.output_shapes[i]);
1568       for (size_t j = 0; j < dim; j++) {
1569         (void)out_shape.emplace_back(extra_info.output_shapes[i][j]);
1570       }
1571       (void)out_shapes.emplace_back(out_shape);
1572       (void)out_dtypes.emplace_back(extra_info.output_dtypes[i]);
1573     }
1574   }
1575   for (size_t i = 0; i < extra_info.attr_num; i++) {
1576     MS_EXCEPTION_IF_NULL(extra_info.attr_names[i]);
1577     auto value = GetSrcPtr<ValuePtr>(res_mgr, extra_info.attr_values[i]);
1578     if (value == nullptr) {
1579       MS_LOG(ERROR) << "Get attribute's source pointer failed, attribute index: " << i;
1580     }
1581     (void)attrs_pair.emplace_back(std::make_pair(extra_info.attr_names[i], value));
1582   }
1583   return std::make_shared<InnerOpInfo>(op_type, src_inputs, out_shapes, out_dtypes, attrs_pair);
1584 }
1585 
CheckExtraInfo(const DynamicOpInfo & extra_info)1586 STATUS CheckExtraInfo(const DynamicOpInfo &extra_info) {
1587   MS_ERROR_IF_TRUE_W_RET_N_LOG(extra_info.attr_num < 0, RET_ERROR, "The attr_num must be non-zero!");
1588   MS_ERROR_IF_TRUE_W_RET_N_LOG(
1589     extra_info.attr_num == 0 && (extra_info.attr_names != nullptr || extra_info.attr_values != nullptr), RET_ERROR,
1590     "The attr_name and attr_values must be nullptr if attr_num is 0!");
1591   MS_ERROR_IF_TRUE_W_RET_N_LOG(
1592     extra_info.attr_num != 0 && (extra_info.attr_names == nullptr || extra_info.attr_values == nullptr), RET_ERROR,
1593     "The attr_name and attr_values must be specified if attr_num is non-negative!");
1594   MS_ERROR_IF_TRUE_W_RET_N_LOG(extra_info.output_dims != nullptr && extra_info.output_shapes == nullptr, RET_ERROR,
1595                                "The output_shapes must be not nullptr if output_dims is non-zero!");
1596   return RET_OK;
1597 }
1598 
OpRunInfoSetInputs(ResMgrHandle res_mgr,TensorHandle const inputs[],size_t input_num,FrontendOpRunInfoPtr op_run_info)1599 STATUS OpRunInfoSetInputs(ResMgrHandle res_mgr, TensorHandle const inputs[], size_t input_num,
1600                           FrontendOpRunInfoPtr op_run_info) {
1601   auto prim = op_run_info->op_grad_info->op_prim;
1602   MS_EXCEPTION_IF_NULL(prim);
1603   op_run_info->input_size = input_num;
1604   op_run_info->op_grad_info->input_value.resize(input_num);
1605   for (size_t i = 0; i < input_num; i++) {
1606     auto in_arg = GetSrcPtr<ValuePtr>(res_mgr, inputs[i]);
1607     if (in_arg == nullptr) {
1608       MS_LOG(ERROR) << "Invalid input. Index: " << i;
1609       return RET_ERROR;
1610     }
1611     op_run_info->op_grad_info->input_value[i] = in_arg;
1612   }
1613   return RET_OK;
1614 }
1615 
DynamicOpInfer(size_t output_num,FrontendOpRunInfoPtr op_run_info,const DynamicOpInfo & extra_info)1616 STATUS DynamicOpInfer(size_t output_num, FrontendOpRunInfoPtr op_run_info, const DynamicOpInfo &extra_info) {
1617   MS_EXCEPTION_IF_NULL(op_run_info);
1618   // get abstract
1619   op_run_info->op_grad_info->input_abs.resize(op_run_info->input_size);
1620   for (size_t i = 0; i < op_run_info->input_size; ++i) {
1621     auto input_value = op_run_info->op_grad_info->input_value[i];
1622     op_run_info->op_grad_info->input_abs[i] = input_value->ToAbstract();
1623   }
1624   // do infer
1625   AbstractBasePtr out_abs = nullptr;
1626   auto prim = op_run_info->op_grad_info->op_prim;
1627   if (extra_info.output_shapes != nullptr && extra_info.output_dims != nullptr && extra_info.output_dtypes != nullptr) {
1628     auto shape = BuildShape(extra_info.output_shapes, extra_info.output_dims, output_num);
1629     auto type = BuildType(extra_info.output_dtypes, output_num);
1630     out_abs = BuildAbstract(shape, type);
1631   } else {
1632     MS_LOG(INFO) << "Output shapes and dtypes info is not specified completely, using inner infer.";
1633     prim->BeginRecordAddAttr();
1634     out_abs = OpInferShapeAndType(prim, op_run_info->op_grad_info->input_abs);
1635     prim->EndRecordAddAttr();
1636   }
1637   MS_EXCEPTION_IF_NULL(out_abs);
1638   op_run_info->base_op_run_info.abstract = out_abs;
1639   return RET_OK;
1640 }
1641 
DynamicOpGetMindRTBackend(ResMgrHandle res_mgr,const string & cur_device_target,uint32_t device_id)1642 MindRTBackendPtr DynamicOpGetMindRTBackend(ResMgrHandle res_mgr, const string &cur_device_target, uint32_t device_id) {
1643   auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
1644   auto cached_backend = res_mgr_ptr->GetBackendFromCache(cur_device_target);
1645   if (cached_backend != nullptr) {
1646     return cached_backend;
1647   } else {
1648     std::lock_guard<std::mutex> guard(mindspore::pipeline::Resource::GetBackendInitMutex());
1649     auto backend = std::make_shared<mindspore::compile::MindRTBackend>("ms", cur_device_target, device_id);
1650     MS_EXCEPTION_IF_NULL(backend);
1651     res_mgr_ptr->CacheBackend(cur_device_target, backend);
1652     return backend;
1653   }
1654 }
1655 
DynamicOpRun(ResMgrHandle res_mgr,const FrontendOpRunInfoPtr & op_run_info)1656 ValuePtr DynamicOpRun(ResMgrHandle res_mgr, const FrontendOpRunInfoPtr &op_run_info) {
1657   MS_LOG(DEBUG) << "DynamicOpRun start";
1658   MS_EXCEPTION_IF_NULL(op_run_info);
1659   auto ms_context = mindspore::MsContext::GetInstance();
1660   MS_EXCEPTION_IF_NULL(ms_context);
1661   auto device_id = ms_context->get_param<uint32_t>(mindspore::MS_CTX_DEVICE_ID);
1662   ms_context->set_param<bool>(mindspore::MS_CTX_ENABLE_PYNATIVE_INFER, true);
1663   mindspore::pynative::PyNativeAlgo::DataConvert::GetInputTensor(op_run_info, nullptr);
1664   auto backend_op_run_info = std::make_shared<mindspore::BackendOpRunInfo>(
1665     op_run_info->base_op_run_info, std::make_shared<mindspore::Primitive>(*op_run_info->op_grad_info->op_prim), true,
1666     false);
1667 
1668   mindspore::VectorRef outputs;
1669   const auto &cur_mindrt_backend =
1670     DynamicOpGetMindRTBackend(res_mgr, op_run_info->base_op_run_info.device_target, device_id);
1671   MS_EXCEPTION_IF_NULL(cur_mindrt_backend);
1672   py::scoped_interpreter py_scope;
1673   if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
1674     mindspore::AnfAlgo::SetDynamicAttrToPrim(backend_op_run_info->op_prim);
1675     cur_mindrt_backend->RunOpDynamic(backend_op_run_info, &outputs);
1676   } else {
1677     cur_mindrt_backend->RunOp(backend_op_run_info, &outputs);
1678   }
1679 
1680   if (op_run_info->base_op_run_info.has_dynamic_output) {
1681     op_run_info->base_op_run_info.abstract = backend_op_run_info->base_op_run_info.abstract;
1682   }
1683   bool is_out_sequence = (op_run_info->base_op_run_info.abstract == nullptr ||
1684                           op_run_info->base_op_run_info.abstract->isa<mindspore::abstract::AbstractSequence>());
1685   const auto &result = mindspore::pynative::PyNativeAlgo::DataConvert::VectorRefToValue(
1686     outputs, op_run_info->requires_grad, is_out_sequence);
1687   ms_context->set_param<bool>(mindspore::MS_CTX_ENABLE_PYNATIVE_INFER, false);
1688   MS_LOG(DEBUG) << "DynamicOpRun end";
1689   return result;
1690 }
1691 
MSRunOpWithInfo(ResMgrHandle res_mgr,const char * op_type,TensorHandle const inputs[],size_t input_num,TensorHandle outputs[],size_t output_num,DynamicOpInfo extra_info)1692 STATUS MSRunOpWithInfo(ResMgrHandle res_mgr, const char *op_type, TensorHandle const inputs[], size_t input_num,
1693                        TensorHandle outputs[], size_t output_num, DynamicOpInfo extra_info) {
1694   MS_ERROR_IF_TRUE_W_RET_N_LOG(res_mgr == nullptr, RET_NULL_PTR, "Input Handle [res_mgr] is nullptr!");
1695   MS_ERROR_IF_TRUE_W_RET_N_LOG(inputs == nullptr, RET_NULL_PTR, "Input Handle [inputs] is nullptr!");
1696   MS_ERROR_IF_TRUE_W_RET_N_LOG(outputs == nullptr, RET_NULL_PTR, "Input Handle [outputs] is nullptr!");
1697   MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num == 0, RET_NULL_PTR, "Input [input_num] must be non-zero!");
1698   MS_ERROR_IF_TRUE_W_RET_N_LOG(output_num == 0, RET_NULL_PTR, "Input [output_num] must be non-zero!");
1699   MS_ERROR_IF_TRUE_W_RET_N_LOG(CheckExtraInfo(extra_info) != RET_OK, RET_NULL_PTR, "Input [extra_info] is invalid!");
1700   try {
1701     auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
1702     FrontendOpRunInfoPtr op_run_info = nullptr;
1703     auto op_info = GenerateInnerInfo(res_mgr, op_type, inputs, input_num, output_num, extra_info);
1704     auto cached_run_info = res_mgr_ptr->GetOpRunInfoFromCache(op_info);
1705     if (cached_run_info != nullptr) {
1706       op_run_info = cached_run_info;
1707       // set inputs
1708       auto ret = OpRunInfoSetInputs(res_mgr, inputs, input_num, op_run_info);
1709       if (ret != RET_OK) {
1710         MS_LOG(ERROR) << "Dynamic Op set inputs failed.";
1711         return RET_ERROR;
1712       }
1713     } else {
1714       // create op_run_info
1715       op_run_info = std::make_shared<mindspore::pynative::FrontendOpRunInfo>();
1716       op_run_info->base_op_run_info.op_name = op_type;
1717       op_run_info->requires_grad = false;
1718       auto ms_context = mindspore::MsContext::GetInstance();
1719       auto cur_target = ms_context->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
1720       op_run_info->base_op_run_info.device_target = cur_target;
1721       // create prim
1722       auto prim = std::make_shared<PrimitiveImpl>(op_type);
1723       op_run_info->op_grad_info->op_prim = prim;
1724       // set inputs
1725       bool is_dynamic_shape =
1726         op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.use_dynamic_shape_process;
1727       mindspore::pynative::PyNativeAlgo::Common::GetConstInputToAttr(prim, op_type, cur_target, is_dynamic_shape,
1728                                                                      &op_run_info->input_to_attr);
1729       auto ret = OpRunInfoSetInputs(res_mgr, inputs, input_num, op_run_info);
1730       if (ret != RET_OK) {
1731         MS_LOG(ERROR) << "Dynamic Op set inputs failed.";
1732         return RET_ERROR;
1733       }
1734       // set args
1735       if (extra_info.attr_names != nullptr && extra_info.attr_values != nullptr) {
1736         ret = OpSetAttrs(res_mgr, prim, extra_info.attr_names, extra_info.attr_values, extra_info.attr_num);
1737         if (ret != RET_OK) {
1738           MS_LOG(ERROR) << "Dynamic Op set attributes failed.";
1739           return RET_ERROR;
1740         }
1741       }
1742       // infer and set abstract
1743       ret = DynamicOpInfer(output_num, op_run_info, extra_info);
1744       if (ret != RET_OK) {
1745         MS_LOG(ERROR) << "Dynamic Op infer shape and type failed.";
1746         return RET_ERROR;
1747       }
1748       // cache op run info
1749       res_mgr_ptr->CacheOpRunInfo(op_info, op_run_info);
1750     }
1751 
1752     // run op
1753     op_run_info->real_out = DynamicOpRun(res_mgr, op_run_info);
1754     if (op_run_info->real_out->isa<ValueSequenceImpl>()) {
1755       const auto &result_v_list = op_run_info->real_out->cast<ValueSequencePtr>();
1756       if (result_v_list->size() == 1 && op_run_info->base_op_run_info.abstract != nullptr &&
1757           !op_run_info->base_op_run_info.abstract->isa<mindspore::abstract::AbstractSequence>()) {
1758         op_run_info->real_out = result_v_list->value().front();
1759       }
1760     }
1761 
1762     // clear used input tensor
1763     op_run_info->base_op_run_info.expanded_input_values.clear();
1764     op_run_info->base_op_run_info.input_types.clear();
1765 
1766     // get output tensor
1767     const std::vector<TensorPtr> &ref_outputs = ConvertOutputToTensor(op_run_info->real_out);
1768     if (ref_outputs.size() != output_num) {
1769       MS_LOG(ERROR) << "Invalid outputs number, it should be: " << ref_outputs.size() << ", but got: " << output_num;
1770       return RET_ERROR;
1771     }
1772     for (size_t i = 0; i < output_num; i++) {
1773       outputs[i] = GetRawPtr(res_mgr, ref_outputs[i]);
1774     }
1775   } catch (const std::exception &e) {
1776     MS_LOG(ERROR) << "Run op failed. Error info: " << e.what();
1777     return RET_ERROR;
1778   }
1779   return RET_OK;
1780 }
1781 
MSRunOp(ResMgrHandle res_mgr,const char * op_type,TensorHandle const inputs[],size_t input_num,TensorHandle outputs[],size_t output_num)1782 STATUS MSRunOp(ResMgrHandle res_mgr, const char *op_type, TensorHandle const inputs[], size_t input_num,
1783                TensorHandle outputs[], size_t output_num) {
1784   MS_ERROR_IF_TRUE_W_RET_N_LOG(res_mgr == nullptr, RET_NULL_PTR, "Input Handle [res_mgr] is nullptr!");
1785   MS_ERROR_IF_TRUE_W_RET_N_LOG(inputs == nullptr, RET_NULL_PTR, "Input Handle [inputs] is nullptr!");
1786   MS_ERROR_IF_TRUE_W_RET_N_LOG(outputs == nullptr, RET_NULL_PTR, "Input Handle [outputs] is nullptr!");
1787   MS_ERROR_IF_TRUE_W_RET_N_LOG(input_num == 0, RET_NULL_PTR, "Input [input_num] must be non-zero!");
1788   MS_ERROR_IF_TRUE_W_RET_N_LOG(output_num == 0, RET_NULL_PTR, "Input [output_num] must be non-zero!");
1789   DynamicOpInfo extra_info = {NULL, NULL, 0, NULL, NULL, NULL};
1790   return MSRunOpWithInfo(res_mgr, op_type, inputs, input_num, outputs, output_num, extra_info);
1791 }
1792