1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/c_api/ms/graph.h"
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "c_api/src/helper.h"
20 #include "c_api/src/common.h"
21 #include "c_api/src/utils.h"
22 #include "c_api/src/pass.h"
23 #include "base/base.h"
24 #include "ir/func_graph.h"
25 #include "ir/anf.h"
26 #include "ir/func_graph_cloner.h"
27 #include "utils/ms_context.h"
28 #include "backend/graph_compiler/backend.h"
29 #include "pipeline/jit/ps/pass.h"
30 #include "pipeline/jit/ps/static_analysis/auto_monad.h"
31
MSFuncGraphCreate(ResMgrHandle res_mgr)32 GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr) {
33 if (res_mgr == nullptr) {
34 MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
35 return nullptr;
36 }
37 auto fg = std::make_shared<FuncGraphImpl>();
38 return GetRawPtr(res_mgr, fg);
39 }
40
MSFuncGraphLoad(ResMgrHandle res_mgr,const char * file_path)41 GraphHandle MSFuncGraphLoad(ResMgrHandle res_mgr, const char *file_path) {
42 if (res_mgr == nullptr) {
43 MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
44 return nullptr;
45 }
46 try {
47 mindspore::MindIRLoader mind_loader(false, nullptr, 0, "AES-GCM", false);
48 auto fg = mind_loader.LoadMindIR(file_path);
49 if (fg == nullptr) {
50 MS_LOG(ERROR) << "Load funcgraph from MINDIR fail.";
51 }
52 return GetRawPtr(res_mgr, fg);
53 } catch (const std::exception &e) {
54 MS_LOG(ERROR) << "FuncGraph load failed. Error info: " << e.what();
55 return nullptr;
56 }
57 }
58
MSFuncGraphGetInput(ResMgrHandle res_mgr,ConstGraphHandle graph,size_t i)59 NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
60 if (res_mgr == nullptr || graph == nullptr) {
61 MS_LOG(ERROR) << "Input Handle [res_mgr] or [cnode] is nullptr.";
62 return nullptr;
63 }
64 try {
65 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
66 MS_EXCEPTION_IF_NULL(res_fg);
67 auto fg_inputs = res_fg->get_inputs();
68 if (i >= fg_inputs.size()) {
69 MS_LOG(ERROR) << "Invalid input index, it should be less than " << fg_inputs.size() << ", but got: " << i;
70 return nullptr;
71 }
72 return GetRawPtr(res_mgr, fg_inputs[i]);
73 } catch (const std::exception &e) {
74 MS_LOG(ERROR) << "FuncGraph get inputs failed. Error info: " << e.what();
75 return nullptr;
76 }
77 }
78
MSFuncGraphGetInputNum(ResMgrHandle res_mgr,ConstGraphHandle graph,STATUS * error)79 size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
80 if (error == nullptr) {
81 MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
82 return 0;
83 }
84 if (res_mgr == nullptr || graph == nullptr) {
85 MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
86 *error = RET_NULL_PTR;
87 return 0;
88 }
89 size_t input_num;
90 try {
91 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
92 MS_EXCEPTION_IF_NULL(res_fg);
93 input_num = res_fg->get_inputs().size();
94 } catch (const std::exception &e) {
95 MS_LOG(ERROR) << "FuncGraph get input number failed. Error info: " << e.what();
96 *error = RET_ERROR;
97 return 0;
98 }
99 *error = RET_OK;
100 return input_num;
101 }
102
MSFuncGraphGetInputs(ResMgrHandle res_mgr,ConstGraphHandle graph,NodeHandle inputs[],size_t input_num)103 STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle inputs[], size_t input_num) {
104 if (res_mgr == nullptr || graph == nullptr || inputs == nullptr) {
105 MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
106 return RET_NULL_PTR;
107 }
108 try {
109 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
110 MS_EXCEPTION_IF_NULL(res_fg);
111 auto fg_inputs = res_fg->get_inputs();
112 if (fg_inputs.size() != input_num) {
113 MS_LOG(ERROR) << "Invalid input number, it should be: " << fg_inputs.size() << ", but got: " << input_num;
114 return RET_ERROR;
115 }
116 for (size_t i = 0; i < input_num; i++) {
117 inputs[i] = GetRawPtr(res_mgr, fg_inputs[i]);
118 }
119 } catch (const std::exception &e) {
120 MS_LOG(ERROR) << "FuncGraph get inputs failed. Error info: " << e.what();
121 return RET_ERROR;
122 }
123 return RET_OK;
124 }
125
MSFuncGraphSetOutput(ResMgrHandle res_mgr,GraphHandle graph,ConstNodeHandle op_node,bool force_new_ret)126 STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op_node, bool force_new_ret) {
127 if (res_mgr == nullptr || graph == nullptr || op_node == nullptr) {
128 MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [op_node] is nullptr.";
129 return RET_NULL_PTR;
130 }
131 try {
132 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
133 MS_EXCEPTION_IF_NULL(res_fg);
134 auto res_anfnode = GetSrcPtr<AnfNodePtr>(res_mgr, op_node);
135 MS_EXCEPTION_IF_NULL(res_anfnode);
136 res_fg->set_output(res_anfnode, force_new_ret);
137 } catch (const std::exception &e) {
138 MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
139 return RET_ERROR;
140 }
141 return RET_OK;
142 }
143
MSFuncGraphSetOutputs(ResMgrHandle res_mgr,GraphHandle graph,Handle const outputs[],size_t output_num,bool force_new_ret)144 STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, Handle const outputs[], size_t output_num,
145 bool force_new_ret) {
146 if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
147 MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [outputs] is nullptr.";
148 return RET_NULL_PTR;
149 }
150 try {
151 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
152 MS_EXCEPTION_IF_NULL(res_fg);
153 std::vector<AnfNodePtr> out_nodes{NewValueNode(mindspore::prim::kPrimMakeTuple)};
154 mindspore::AbstractBasePtrList abs_list{};
155 for (size_t i = 0; i < output_num; ++i) {
156 auto out_node = GetSrcPtr<AnfNodePtr>(res_mgr, outputs[i]);
157 MS_EXCEPTION_IF_NULL(out_node);
158 out_nodes.push_back(out_node);
159 ConvertConstScalarInputToTensor(out_node);
160 abs_list.push_back(out_node->abstract());
161 }
162 auto make_tuple_cnode = res_fg->NewCNodeInOrder(out_nodes);
163 make_tuple_cnode->set_abstract(std::make_shared<AbstractTupleImpl>(abs_list));
164 res_fg->set_output(make_tuple_cnode, force_new_ret);
165 } catch (const std::exception &e) {
166 MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
167 return RET_ERROR;
168 }
169 return RET_OK;
170 }
171
MSFuncGraphGetOutput(ResMgrHandle res_mgr,ConstGraphHandle graph,size_t i)172 NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
173 if (res_mgr == nullptr || graph == nullptr) {
174 MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
175 return nullptr;
176 }
177 try {
178 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
179 MS_EXCEPTION_IF_NULL(res_fg);
180 auto out_node = res_fg->output();
181 if (IsPrimitiveCNode(out_node, mindspore::prim::kPrimMakeTuple)) {
182 auto out_cnode = out_node->cast<CNodePtr>();
183 MS_EXCEPTION_IF_NULL(out_cnode);
184 auto out_num = out_cnode->size() - 1;
185 if (i >= out_num) {
186 MS_LOG(ERROR) << "Invalid output index, it should be less than " << out_num << ", but got: " << i;
187 return nullptr;
188 }
189 return GetRawPtr(res_mgr, out_cnode->input(i + 1));
190 } else {
191 if (i >= 1) {
192 MS_LOG(ERROR)
193 << "Invalid output index. The graph has only one output, so the output index should be 0, but got: " << i;
194 return nullptr;
195 }
196 return GetRawPtr(res_mgr, out_node);
197 }
198 } catch (const std::exception &e) {
199 MS_LOG(ERROR) << "FuncGraph get output failed. Error info: " << e.what();
200 return nullptr;
201 }
202 }
203
MSFuncGraphGetOutputNum(ResMgrHandle res_mgr,ConstGraphHandle graph,STATUS * error)204 size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
205 if (error == nullptr) {
206 MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
207 return 0;
208 }
209 if (res_mgr == nullptr || graph == nullptr) {
210 MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [error] is nullptr.";
211 *error = RET_NULL_PTR;
212 return 0;
213 }
214 size_t out_num = 0;
215 try {
216 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
217 MS_EXCEPTION_IF_NULL(res_fg);
218 auto out_node = res_fg->output();
219 if (IsPrimitiveCNode(out_node, mindspore::prim::kPrimMakeTuple)) {
220 auto out_cnode = out_node->cast<CNodePtr>();
221 MS_EXCEPTION_IF_NULL(out_cnode);
222 out_num = out_cnode->size() - 1;
223 } else {
224 out_num = 1;
225 }
226 } catch (const std::exception &e) {
227 MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
228 *error = RET_ERROR;
229 return 0;
230 }
231 return out_num;
232 }
233
MSFuncGraphGetOutputs(ResMgrHandle res_mgr,ConstGraphHandle graph,NodeHandle outputs[],size_t output_num)234 STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle outputs[], size_t output_num) {
235 if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
236 MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
237 return RET_NULL_PTR;
238 }
239 try {
240 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
241 MS_EXCEPTION_IF_NULL(res_fg);
242 size_t out_num = 0;
243 auto out_node = res_fg->output();
244 auto out_cnode = out_node->cast<CNodePtr>();
245 MS_EXCEPTION_IF_NULL(out_cnode);
246 if (IsPrimitiveCNode(out_node, mindspore::prim::kPrimMakeTuple)) {
247 out_num = out_cnode->size() - 1;
248 } else {
249 out_num = 1;
250 }
251 if (out_num != output_num) {
252 MS_LOG(ERROR) << "Invalid output number, it should be: " << out_num << ", but got: " << output_num;
253 return RET_ERROR;
254 }
255 for (size_t i = 0; i < output_num; i++) {
256 outputs[i] = GetRawPtr(res_mgr, out_cnode->input(i + 1));
257 }
258 } catch (const std::exception &e) {
259 MS_LOG(ERROR) << "FuncGraph get inputs failed. Error info: " << e.what();
260 return RET_ERROR;
261 }
262 return RET_OK;
263 }
264
MSFuncGraphReplace(ResMgrHandle res_mgr,GraphHandle graph,ConstNodeHandle old_node,ConstNodeHandle new_node)265 STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle old_node, ConstNodeHandle new_node) {
266 if (res_mgr == nullptr || graph == nullptr || old_node == nullptr || new_node == nullptr) {
267 MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [old_node] or [new_node] is nullptr.";
268 return RET_NULL_PTR;
269 }
270 try {
271 auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
272 MS_EXCEPTION_IF_NULL(res_fg);
273 auto manager_ptr = mindspore::Manage(res_fg, true);
274 MS_EXCEPTION_IF_NULL(manager_ptr);
275 auto res_old_anfnode = GetSrcPtr<AnfNodePtr>(res_mgr, old_node);
276 MS_EXCEPTION_IF_NULL(res_old_anfnode);
277 auto res_new_anfnode = GetSrcPtr<AnfNodePtr>(res_mgr, new_node);
278 MS_EXCEPTION_IF_NULL(res_new_anfnode);
279 (void)manager_ptr->Replace(res_old_anfnode, res_new_anfnode);
280 } catch (const std::exception &e) {
281 MS_LOG(ERROR) << "FuncGraph replace failed. Error info: " << e.what();
282 return RET_ERROR;
283 }
284 return RET_OK;
285 }
286
MSFuncGraphCompile(ResMgrHandle res_mgr,GraphHandle graph,OptPassID * opt_pass,size_t pass_num)287 STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph, OptPassID *opt_pass, size_t pass_num) {
288 if (res_mgr == nullptr || graph == nullptr) {
289 MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
290 return RET_NULL_PTR;
291 }
292 auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
293 auto context_ptr = mindspore::MsContext::GetInstance();
294 try {
295 auto func_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
296 MS_EXCEPTION_IF_NULL(func_graph);
297 auto fg_mgr = mindspore::MakeManager();
298 MS_EXCEPTION_IF_NULL(fg_mgr);
299 fg_mgr->AddFuncGraph(func_graph, true);
300 func_graph->set_manager(fg_mgr);
301 (void)mindspore::pipeline::AutoMonad(func_graph);
302 (void)mindspore::LiftingClone(func_graph);
303 context_ptr->Refresh();
304 std::string backend_name = context_ptr->backend_policy();
305 std::string target = context_ptr->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
306 uint32_t device_id = context_ptr->get_param<uint32_t>(mindspore::MS_CTX_DEVICE_ID);
307 auto backend = res_mgr_ptr->GetBackendFromCache(target);
308 if (backend == nullptr) {
309 backend = std::make_shared<mindspore::compile::MindRTBackend>(backend_name, target, device_id);
310 res_mgr_ptr->CacheBackend(target, backend);
311 }
312 if (target == mindspore::kAscendDevice &&
313 context_ptr->get_param<int>(mindspore::MS_CTX_EXECUTION_MODE) == mindspore::kPynativeMode) {
314 backend->set_is_multi_graph_sink(false);
315 }
316 for (size_t i = 0; i < pass_num; i++) {
317 auto iter = kPassEnumToFuncMap.find(opt_pass[i]);
318 if (iter == kPassEnumToFuncMap.end()) {
319 MS_LOG(ERROR) << "Unsupported optimization pass: " << opt_pass[i];
320 return RET_ERROR;
321 }
322 auto pass_func = iter->second;
323 auto success = pass_func(func_graph);
324 if (!success) {
325 MS_LOG(WARNING) << "Run optimization pass failed! Pass ID: " << opt_pass[i];
326 }
327 }
328 context_ptr->set_param<bool>(mindspore::MS_CTX_ENABLE_MINDRT, true);
329 auto actor_info = backend->CompileGraphs(func_graph);
330 res_mgr_ptr->SetResult(mindspore::pipeline::kOutput, actor_info);
331 } catch (const std::exception &e) {
332 MS_LOG(ERROR) << "FuncGraph compile failed. Error info: " << e.what();
333 return RET_ERROR;
334 }
335 return RET_OK;
336 }
337
MSFuncGraphRun(ResMgrHandle res_mgr,GraphHandle graph,Handle const inputs[],size_t input_num,TensorHandle outputs[],size_t outputs_num)338 STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, Handle const inputs[], size_t input_num,
339 TensorHandle outputs[], size_t outputs_num) {
340 if (res_mgr == nullptr || inputs == nullptr || outputs == nullptr) {
341 MS_LOG(ERROR) << "Input Handle [res_mgr] or [inputs] or [outputs] is nullptr.";
342 return RET_NULL_PTR;
343 }
344 mindspore::VectorRef args;
345 for (size_t i = 0; i < input_num; i++) {
346 auto in_arg = GetSrcPtr<ValuePtr>(res_mgr, inputs[i]);
347 if (in_arg == nullptr) {
348 MS_LOG(ERROR) << "Invalid input. Index: " << i;
349 return RET_NULL_PTR;
350 }
351 args.push_back(in_arg);
352 }
353 auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
354 auto raw_info = res_mgr_ptr->GetResult(mindspore::pipeline::kOutput);
355 mindspore::VectorRef out_vec;
356 try {
357 auto context_ptr = mindspore::MsContext::GetInstance();
358 std::string target = context_ptr->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
359 auto mindrt_bc_ptr = res_mgr_ptr->GetBackendFromCache(target);
360 MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
361 auto func_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
362 MS_EXCEPTION_IF_NULL(func_graph);
363 auto params_anf = func_graph->parameters();
364 for (auto p : params_anf) {
365 auto param = p->cast<ParameterPtr>();
366 if (param->has_default()) {
367 auto value_ptr = param->default_param();
368 auto tensor_ptr = value_ptr->cast<TensorPtr>();
369 args.push_back(tensor_ptr);
370 }
371 }
372 const auto actor_info = raw_info.cast<mindspore::compile::ActorInfo>();
373 mindrt_bc_ptr->RunGraph(actor_info, args, &out_vec);
374 const std::vector<TensorPtr> &ref_outputs = ConvertOutputToTensor(out_vec);
375 if (ref_outputs.size() != outputs_num) {
376 MS_LOG(ERROR) << "Invalid outputs number, it should be: " << ref_outputs.size() << ", but got: " << outputs_num;
377 return RET_ERROR;
378 }
379 for (size_t i = 0; i < outputs_num; ++i) {
380 outputs[i] = GetRawPtr(res_mgr, ref_outputs[i]);
381 }
382 } catch (const std::exception &e) {
383 MS_LOG(ERROR) << "FuncGraph compile failed. Error info: " << e.what();
384 return RET_ERROR;
385 }
386 return RET_OK;
387 }
388