/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "include/c_api/ms/graph.h"
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "c_api/src/helper.h"
20 #include "c_api/src/common.h"
21 #include "c_api/src/utils.h"
22 #include "c_api/src/pass.h"
23 #include "base/base.h"
24 #include "ir/func_graph.h"
25 #include "ir/anf.h"
26 #include "ir/func_graph_cloner.h"
27 #include "utils/ms_context.h"
28 #include "backend/graph_compiler/backend.h"
29 #include "pipeline/jit/ps/pass.h"
30 #include "pipeline/jit/ps/static_analysis/auto_monad.h"
31 
MSFuncGraphCreate(ResMgrHandle res_mgr)32 GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr) {
33   if (res_mgr == nullptr) {
34     MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
35     return nullptr;
36   }
37   auto fg = std::make_shared<FuncGraphImpl>();
38   return GetRawPtr(res_mgr, fg);
39 }
40 
MSFuncGraphLoad(ResMgrHandle res_mgr,const char * file_path)41 GraphHandle MSFuncGraphLoad(ResMgrHandle res_mgr, const char *file_path) {
42   if (res_mgr == nullptr) {
43     MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
44     return nullptr;
45   }
46   try {
47     mindspore::MindIRLoader mind_loader(false, nullptr, 0, "AES-GCM", false);
48     auto fg = mind_loader.LoadMindIR(file_path);
49     if (fg == nullptr) {
50       MS_LOG(ERROR) << "Load funcgraph from MINDIR fail.";
51     }
52     return GetRawPtr(res_mgr, fg);
53   } catch (const std::exception &e) {
54     MS_LOG(ERROR) << "FuncGraph load failed. Error info: " << e.what();
55     return nullptr;
56   }
57 }
58 
MSFuncGraphGetInput(ResMgrHandle res_mgr,ConstGraphHandle graph,size_t i)59 NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
60   if (res_mgr == nullptr || graph == nullptr) {
61     MS_LOG(ERROR) << "Input Handle [res_mgr] or [cnode] is nullptr.";
62     return nullptr;
63   }
64   try {
65     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
66     MS_EXCEPTION_IF_NULL(res_fg);
67     auto fg_inputs = res_fg->get_inputs();
68     if (i >= fg_inputs.size()) {
69       MS_LOG(ERROR) << "Invalid input index, it should be less than " << fg_inputs.size() << ", but got: " << i;
70       return nullptr;
71     }
72     return GetRawPtr(res_mgr, fg_inputs[i]);
73   } catch (const std::exception &e) {
74     MS_LOG(ERROR) << "FuncGraph get inputs failed. Error info: " << e.what();
75     return nullptr;
76   }
77 }
78 
MSFuncGraphGetInputNum(ResMgrHandle res_mgr,ConstGraphHandle graph,STATUS * error)79 size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
80   if (error == nullptr) {
81     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
82     return 0;
83   }
84   if (res_mgr == nullptr || graph == nullptr) {
85     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
86     *error = RET_NULL_PTR;
87     return 0;
88   }
89   size_t input_num;
90   try {
91     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
92     MS_EXCEPTION_IF_NULL(res_fg);
93     input_num = res_fg->get_inputs().size();
94   } catch (const std::exception &e) {
95     MS_LOG(ERROR) << "FuncGraph get input number failed. Error info: " << e.what();
96     *error = RET_ERROR;
97     return 0;
98   }
99   *error = RET_OK;
100   return input_num;
101 }
102 
MSFuncGraphGetInputs(ResMgrHandle res_mgr,ConstGraphHandle graph,NodeHandle inputs[],size_t input_num)103 STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle inputs[], size_t input_num) {
104   if (res_mgr == nullptr || graph == nullptr || inputs == nullptr) {
105     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
106     return RET_NULL_PTR;
107   }
108   try {
109     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
110     MS_EXCEPTION_IF_NULL(res_fg);
111     auto fg_inputs = res_fg->get_inputs();
112     if (fg_inputs.size() != input_num) {
113       MS_LOG(ERROR) << "Invalid input number, it should be: " << fg_inputs.size() << ", but got: " << input_num;
114       return RET_ERROR;
115     }
116     for (size_t i = 0; i < input_num; i++) {
117       inputs[i] = GetRawPtr(res_mgr, fg_inputs[i]);
118     }
119   } catch (const std::exception &e) {
120     MS_LOG(ERROR) << "FuncGraph get inputs failed. Error info: " << e.what();
121     return RET_ERROR;
122   }
123   return RET_OK;
124 }
125 
MSFuncGraphSetOutput(ResMgrHandle res_mgr,GraphHandle graph,ConstNodeHandle op_node,bool force_new_ret)126 STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op_node, bool force_new_ret) {
127   if (res_mgr == nullptr || graph == nullptr || op_node == nullptr) {
128     MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [op_node] is nullptr.";
129     return RET_NULL_PTR;
130   }
131   try {
132     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
133     MS_EXCEPTION_IF_NULL(res_fg);
134     auto res_anfnode = GetSrcPtr<AnfNodePtr>(res_mgr, op_node);
135     MS_EXCEPTION_IF_NULL(res_anfnode);
136     res_fg->set_output(res_anfnode, force_new_ret);
137   } catch (const std::exception &e) {
138     MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
139     return RET_ERROR;
140   }
141   return RET_OK;
142 }
143 
MSFuncGraphSetOutputs(ResMgrHandle res_mgr,GraphHandle graph,Handle const outputs[],size_t output_num,bool force_new_ret)144 STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, Handle const outputs[], size_t output_num,
145                              bool force_new_ret) {
146   if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
147     MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [outputs] is nullptr.";
148     return RET_NULL_PTR;
149   }
150   try {
151     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
152     MS_EXCEPTION_IF_NULL(res_fg);
153     std::vector<AnfNodePtr> out_nodes{NewValueNode(mindspore::prim::kPrimMakeTuple)};
154     mindspore::AbstractBasePtrList abs_list{};
155     for (size_t i = 0; i < output_num; ++i) {
156       auto out_node = GetSrcPtr<AnfNodePtr>(res_mgr, outputs[i]);
157       MS_EXCEPTION_IF_NULL(out_node);
158       out_nodes.push_back(out_node);
159       ConvertConstScalarInputToTensor(out_node);
160       abs_list.push_back(out_node->abstract());
161     }
162     auto make_tuple_cnode = res_fg->NewCNodeInOrder(out_nodes);
163     make_tuple_cnode->set_abstract(std::make_shared<AbstractTupleImpl>(abs_list));
164     res_fg->set_output(make_tuple_cnode, force_new_ret);
165   } catch (const std::exception &e) {
166     MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
167     return RET_ERROR;
168   }
169   return RET_OK;
170 }
171 
MSFuncGraphGetOutput(ResMgrHandle res_mgr,ConstGraphHandle graph,size_t i)172 NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
173   if (res_mgr == nullptr || graph == nullptr) {
174     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
175     return nullptr;
176   }
177   try {
178     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
179     MS_EXCEPTION_IF_NULL(res_fg);
180     auto out_node = res_fg->output();
181     if (IsPrimitiveCNode(out_node, mindspore::prim::kPrimMakeTuple)) {
182       auto out_cnode = out_node->cast<CNodePtr>();
183       MS_EXCEPTION_IF_NULL(out_cnode);
184       auto out_num = out_cnode->size() - 1;
185       if (i >= out_num) {
186         MS_LOG(ERROR) << "Invalid output index, it should be less than " << out_num << ", but got: " << i;
187         return nullptr;
188       }
189       return GetRawPtr(res_mgr, out_cnode->input(i + 1));
190     } else {
191       if (i >= 1) {
192         MS_LOG(ERROR)
193           << "Invalid output index. The graph has only one output, so the output index should be 0, but got: " << i;
194         return nullptr;
195       }
196       return GetRawPtr(res_mgr, out_node);
197     }
198   } catch (const std::exception &e) {
199     MS_LOG(ERROR) << "FuncGraph get output failed. Error info: " << e.what();
200     return nullptr;
201   }
202 }
203 
MSFuncGraphGetOutputNum(ResMgrHandle res_mgr,ConstGraphHandle graph,STATUS * error)204 size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
205   if (error == nullptr) {
206     MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
207     return 0;
208   }
209   if (res_mgr == nullptr || graph == nullptr) {
210     MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [error] is nullptr.";
211     *error = RET_NULL_PTR;
212     return 0;
213   }
214   size_t out_num = 0;
215   try {
216     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
217     MS_EXCEPTION_IF_NULL(res_fg);
218     auto out_node = res_fg->output();
219     if (IsPrimitiveCNode(out_node, mindspore::prim::kPrimMakeTuple)) {
220       auto out_cnode = out_node->cast<CNodePtr>();
221       MS_EXCEPTION_IF_NULL(out_cnode);
222       out_num = out_cnode->size() - 1;
223     } else {
224       out_num = 1;
225     }
226   } catch (const std::exception &e) {
227     MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
228     *error = RET_ERROR;
229     return 0;
230   }
231   return out_num;
232 }
233 
MSFuncGraphGetOutputs(ResMgrHandle res_mgr,ConstGraphHandle graph,NodeHandle outputs[],size_t output_num)234 STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle outputs[], size_t output_num) {
235   if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
236     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
237     return RET_NULL_PTR;
238   }
239   try {
240     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
241     MS_EXCEPTION_IF_NULL(res_fg);
242     size_t out_num = 0;
243     auto out_node = res_fg->output();
244     auto out_cnode = out_node->cast<CNodePtr>();
245     MS_EXCEPTION_IF_NULL(out_cnode);
246     if (IsPrimitiveCNode(out_node, mindspore::prim::kPrimMakeTuple)) {
247       out_num = out_cnode->size() - 1;
248     } else {
249       out_num = 1;
250     }
251     if (out_num != output_num) {
252       MS_LOG(ERROR) << "Invalid output number, it should be: " << out_num << ", but got: " << output_num;
253       return RET_ERROR;
254     }
255     for (size_t i = 0; i < output_num; i++) {
256       outputs[i] = GetRawPtr(res_mgr, out_cnode->input(i + 1));
257     }
258   } catch (const std::exception &e) {
259     MS_LOG(ERROR) << "FuncGraph get inputs failed. Error info: " << e.what();
260     return RET_ERROR;
261   }
262   return RET_OK;
263 }
264 
MSFuncGraphReplace(ResMgrHandle res_mgr,GraphHandle graph,ConstNodeHandle old_node,ConstNodeHandle new_node)265 STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle old_node, ConstNodeHandle new_node) {
266   if (res_mgr == nullptr || graph == nullptr || old_node == nullptr || new_node == nullptr) {
267     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [old_node] or [new_node] is nullptr.";
268     return RET_NULL_PTR;
269   }
270   try {
271     auto res_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
272     MS_EXCEPTION_IF_NULL(res_fg);
273     auto manager_ptr = mindspore::Manage(res_fg, true);
274     MS_EXCEPTION_IF_NULL(manager_ptr);
275     auto res_old_anfnode = GetSrcPtr<AnfNodePtr>(res_mgr, old_node);
276     MS_EXCEPTION_IF_NULL(res_old_anfnode);
277     auto res_new_anfnode = GetSrcPtr<AnfNodePtr>(res_mgr, new_node);
278     MS_EXCEPTION_IF_NULL(res_new_anfnode);
279     (void)manager_ptr->Replace(res_old_anfnode, res_new_anfnode);
280   } catch (const std::exception &e) {
281     MS_LOG(ERROR) << "FuncGraph replace failed. Error info: " << e.what();
282     return RET_ERROR;
283   }
284   return RET_OK;
285 }
286 
MSFuncGraphCompile(ResMgrHandle res_mgr,GraphHandle graph,OptPassID * opt_pass,size_t pass_num)287 STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph, OptPassID *opt_pass, size_t pass_num) {
288   if (res_mgr == nullptr || graph == nullptr) {
289     MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
290     return RET_NULL_PTR;
291   }
292   auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
293   auto context_ptr = mindspore::MsContext::GetInstance();
294   try {
295     auto func_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
296     MS_EXCEPTION_IF_NULL(func_graph);
297     auto fg_mgr = mindspore::MakeManager();
298     MS_EXCEPTION_IF_NULL(fg_mgr);
299     fg_mgr->AddFuncGraph(func_graph, true);
300     func_graph->set_manager(fg_mgr);
301     (void)mindspore::pipeline::AutoMonad(func_graph);
302     (void)mindspore::LiftingClone(func_graph);
303     context_ptr->Refresh();
304     std::string backend_name = context_ptr->backend_policy();
305     std::string target = context_ptr->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
306     uint32_t device_id = context_ptr->get_param<uint32_t>(mindspore::MS_CTX_DEVICE_ID);
307     auto backend = res_mgr_ptr->GetBackendFromCache(target);
308     if (backend == nullptr) {
309       backend = std::make_shared<mindspore::compile::MindRTBackend>(backend_name, target, device_id);
310       res_mgr_ptr->CacheBackend(target, backend);
311     }
312     if (target == mindspore::kAscendDevice &&
313         context_ptr->get_param<int>(mindspore::MS_CTX_EXECUTION_MODE) == mindspore::kPynativeMode) {
314       backend->set_is_multi_graph_sink(false);
315     }
316     for (size_t i = 0; i < pass_num; i++) {
317       auto iter = kPassEnumToFuncMap.find(opt_pass[i]);
318       if (iter == kPassEnumToFuncMap.end()) {
319         MS_LOG(ERROR) << "Unsupported optimization pass: " << opt_pass[i];
320         return RET_ERROR;
321       }
322       auto pass_func = iter->second;
323       auto success = pass_func(func_graph);
324       if (!success) {
325         MS_LOG(WARNING) << "Run optimization pass failed! Pass ID: " << opt_pass[i];
326       }
327     }
328     context_ptr->set_param<bool>(mindspore::MS_CTX_ENABLE_MINDRT, true);
329     auto actor_info = backend->CompileGraphs(func_graph);
330     res_mgr_ptr->SetResult(mindspore::pipeline::kOutput, actor_info);
331   } catch (const std::exception &e) {
332     MS_LOG(ERROR) << "FuncGraph compile failed. Error info: " << e.what();
333     return RET_ERROR;
334   }
335   return RET_OK;
336 }
337 
MSFuncGraphRun(ResMgrHandle res_mgr,GraphHandle graph,Handle const inputs[],size_t input_num,TensorHandle outputs[],size_t outputs_num)338 STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, Handle const inputs[], size_t input_num,
339                       TensorHandle outputs[], size_t outputs_num) {
340   if (res_mgr == nullptr || inputs == nullptr || outputs == nullptr) {
341     MS_LOG(ERROR) << "Input Handle [res_mgr] or [inputs] or [outputs] is nullptr.";
342     return RET_NULL_PTR;
343   }
344   mindspore::VectorRef args;
345   for (size_t i = 0; i < input_num; i++) {
346     auto in_arg = GetSrcPtr<ValuePtr>(res_mgr, inputs[i]);
347     if (in_arg == nullptr) {
348       MS_LOG(ERROR) << "Invalid input. Index: " << i;
349       return RET_NULL_PTR;
350     }
351     args.push_back(in_arg);
352   }
353   auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
354   auto raw_info = res_mgr_ptr->GetResult(mindspore::pipeline::kOutput);
355   mindspore::VectorRef out_vec;
356   try {
357     auto context_ptr = mindspore::MsContext::GetInstance();
358     std::string target = context_ptr->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
359     auto mindrt_bc_ptr = res_mgr_ptr->GetBackendFromCache(target);
360     MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
361     auto func_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
362     MS_EXCEPTION_IF_NULL(func_graph);
363     auto params_anf = func_graph->parameters();
364     for (auto p : params_anf) {
365       auto param = p->cast<ParameterPtr>();
366       if (param->has_default()) {
367         auto value_ptr = param->default_param();
368         auto tensor_ptr = value_ptr->cast<TensorPtr>();
369         args.push_back(tensor_ptr);
370       }
371     }
372     const auto actor_info = raw_info.cast<mindspore::compile::ActorInfo>();
373     mindrt_bc_ptr->RunGraph(actor_info, args, &out_vec);
374     const std::vector<TensorPtr> &ref_outputs = ConvertOutputToTensor(out_vec);
375     if (ref_outputs.size() != outputs_num) {
376       MS_LOG(ERROR) << "Invalid outputs number, it should be: " << ref_outputs.size() << ", but got: " << outputs_num;
377       return RET_ERROR;
378     }
379     for (size_t i = 0; i < outputs_num; ++i) {
380       outputs[i] = GetRawPtr(res_mgr, ref_outputs[i]);
381     }
382   } catch (const std::exception &e) {
383     MS_LOG(ERROR) << "FuncGraph compile failed. Error info: " << e.what();
384     return RET_ERROR;
385   }
386   return RET_OK;
387 }
388