• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/pipeline.h"
20 
21 #include <memory>
22 #include <sstream>
23 #include <map>
24 #include <unordered_map>
25 #include <cstdlib>
26 #include <algorithm>
27 #include <iomanip>
28 
29 #include "ir/param_info.h"
30 #include "pipeline/jit/pass.h"
31 #include "pipeline/jit/parse/data_converter.h"
32 #include "frontend/optimizer/ad/dfunctor.h"
33 #include "pipeline/jit/static_analysis/async_eval_result.h"
34 #include "debug/anf_ir_dump.h"
35 #include "debug/dump_proto.h"
36 #include "debug/anf_ir_utils.h"
37 #include "debug/common.h"
38 #include "utils/config_manager.h"
39 #include "utils/convert_utils.h"
40 #include "utils/convert_utils_py.h"
41 #include "utils/context/context_extends.h"
42 #include "vm/segment_runner.h"
43 #include "frontend/parallel/context.h"
44 #include "frontend/parallel/graph_util/get_parallel_info.h"
45 #include "runtime/device/kernel_runtime_manager.h"
46 #include "backend/session/executor_manager.h"
47 #include "debug/trace.h"
48 #include "debug/draw.h"
49 #include "pipeline/pynative/pynative_execute.h"
50 #include "frontend/optimizer/py_pass_manager.h"
51 #include "pybind_api/pybind_patch.h"
52 #include "utils/shape_utils.h"
53 #include "utils/info.h"
54 #include "load_mindir/load_model.h"
55 #include "frontend/optimizer/ad/prim_bprop_optimizer.h"
56 #include "runtime/hardware/device_context_manager.h"
57 #include "utils/crypto.h"
58 #include "utils/comm_manager.h"
59 #if ((defined ENABLE_CPU) && (!defined _WIN32))
60 #include "ps/constants.h"
61 #include "ps/util.h"
62 #include "ps/worker.h"
63 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
64 #include "ps/ps_cache/ps_cache_manager.h"
65 #include "fl/server/server.h"
66 #include "fl/worker/fl_worker.h"
67 #endif
68 
69 #if ((defined ENABLE_GE) || (defined ENABLE_D))
70 #include "pipeline/jit/pipeline_ge.h"
71 #include "transform/graph_ir/convert.h"
72 #include "transform/graph_ir/df_graph_manager.h"
73 #include "transform/graph_ir/op_adapter_map.h"
74 #include "runtime/device/ascend/profiling/profiling_manager.h"
75 #include "runtime/device/ascend/distribute/ascend_collective.h"
76 #endif
77 #ifdef ENABLE_DUMP_IR
78 #include "debug/rdr/running_data_recorder.h"
79 #include "debug/rdr/recorder_manager.h"
80 #endif
81 
namespace mindspore {
// namespace to support intermediate representation definition
namespace pipeline {
using Tensor = mindspore::tensor::Tensor;
using MetaTensor = mindspore::tensor::MetaTensor;
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;

#ifdef ENABLE_D
#ifndef ENABLE_SECURITY
using mindspore::device::ascend::ProfilingManager;
#endif
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
#endif

// IR type tags accepted by GraphExecutorPy::GetFuncGraphProto.
const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
const char IR_TYPE_MINDIR[] = "mind_ir";

// Process-wide executor singleton and the lock guarding its creation.
GraphExecutorPyPtr GraphExecutorPy::executor_ = nullptr;
std::mutex GraphExecutorPy::instance_lock_;
#ifdef ENABLE_DEBUGGER
bool GraphExecutorPy::debugger_terminate_ = false;
bool GraphExecutorPy::exit_success_ = false;
#endif

// Cache mapping an argument abstract-signature to its small-integer compile
// key (see GenerateArgumentsKey).
std::unordered_map<abstract::AbstractBasePtrList, uint64_t, abstract::AbstractBasePtrListHasher,
                   abstract::AbstractBasePtrListEqual>
  g_args_cache;
114 
115 namespace {
116 constexpr char kCompileCacheFilePath[] = "compile_cache.mindir";
117 #ifdef ENABLE_DUMP_IR
// Build a dump-file base name like "03_parse": the stage index zero-padded to
// at least two digits, then an underscore and the action name.
std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
  std::string idx_text = std::to_string(stage_idx);
  constexpr size_t kMinWidth = 2;
  if (idx_text.size() < kMinWidth) {
    idx_text.insert(idx_text.begin(), kMinWidth - idx_text.size(), '0');
  }
  return idx_text + "_" + action_name;
}
124 #endif
125 
ArgsToAbstract(const ValuePtr & value)126 AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
127   MS_EXCEPTION_IF_NULL(value);
128   bool broaden = value->isa<MetaTensor>() ||
129                  (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
130 
131   return abstract::FromValue(value, broaden);
132 }
133 
CheckArgValid(const py::handle & arg)134 bool CheckArgValid(const py::handle &arg) {
135   if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
136     auto vector_arg = py::cast<py::list>(arg);
137     return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid);
138   }
139 
140   if (py::isinstance<py::dict>(arg)) {
141     auto dict_arg = py::cast<py::dict>(arg);
142     return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
143   }
144 
145   return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
146          py::isinstance<Number>(arg) || (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
147 }
148 
GetCompileExceptionInfo()149 std::string GetCompileExceptionInfo() {
150   std::ostringstream oss;
151   trace::GetTraceStackInfo(oss);
152   return oss.str();
153 }
154 
SetLoopCount(const ResourcePtr & resource)155 void SetLoopCount(const ResourcePtr &resource) {
156   MS_EXCEPTION_IF_NULL(resource);
157   auto func_graph = resource->func_graph();
158   if (func_graph != nullptr && func_graph->manager() != nullptr) {
159     auto manager = func_graph->manager();
160     size_t graph_nums = manager->func_graphs().size();
161     int64_t loop_size = ConfigManager::GetInstance().iter_num();
162     const auto context_ptr = MsContext::GetInstance();
163     if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
164       resource->set_vm_loop(!context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK), loop_size);
165     } else if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
166       bool run_with_mind_rt = graph_nums == 1 || context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT);
167       resource->set_vm_loop(!run_with_mind_rt, loop_size);
168     }
169     MS_LOG(INFO) << "Change vm_loop_flag to " << resource->vm_loop_flag() << ", set loop_size to " << loop_size;
170   }
171 }
172 
GetCachedFuncGraph(const ResourcePtr & resource,const std::string & queue_name)173 void GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_name) {
174   MS_EXCEPTION_IF_NULL(resource);
175   auto realpath = Common::CreatePrefixPath(kCompileCacheFilePath);
176   if (!realpath.has_value()) {
177     MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
178   }
179   std::ifstream f(realpath.value());
180   bool cache_file_existed = f.good();
181   f.close();
182   if (!cache_file_existed) {
183     MS_LOG(WARNING) << "The compilation cache file '" << realpath.value()
184                     << "' dose not exist. Execute all the compilation actions.";
185     return;
186   }
187   MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
188   FuncGraphPtr fg = mindspore::LoadMindIR(realpath.value());
189   if (fg == nullptr) {
190     MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
191   }
192   FuncGraphManagerPtr mng = fg->manager();
193   if (mng == nullptr) {
194     auto res_mng = resource->manager();
195     MS_EXCEPTION_IF_NULL(res_mng);
196     res_mng->AddFuncGraph(fg);
197     fg->set_manager(res_mng);
198   }
199   auto cnodes = fg->GetOrderedCnodes();
200   for (auto cnode : cnodes) {
201     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
202     if (prim != nullptr && prim->HasAttr("shared_name")) {
203       prim->set_attr("shared_name", MakeValue(queue_name));
204       break;
205     }
206   }
207   resource->set_func_graph(fg);
208 }
209 
CacheFuncGraph(const ResourcePtr & resource)210 void CacheFuncGraph(const ResourcePtr &resource) {
211   MS_EXCEPTION_IF_NULL(resource);
212   auto realpath = Common::CreatePrefixPath(kCompileCacheFilePath);
213   if (!realpath.has_value()) {
214     MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
215   }
216 
217   ChangeFileMode(realpath.value(), S_IRWXU);
218   std::ofstream fout(realpath.value());
219   if (!fout.is_open()) {
220     MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
221   }
222   FuncGraphPtr fg = resource->func_graph();
223   mind_ir::ModelProto fg_model = GetBinaryProto(fg, true);
224   if (!fg_model.SerializeToOstream(&fout)) {
225     MS_LOG(EXCEPTION) << "Failed to cache the graph to file " << realpath.value();
226   }
227   fout.close();
228   ChangeFileMode(realpath.value(), S_IRUSR);
229 }
230 }  // namespace
231 
CheckArgsValid(const py::tuple & args)232 void CheckArgsValid(const py::tuple &args) {
233   for (size_t i = 0; i < args.size(); i++) {
234     if (!CheckArgValid(args[i])) {
235       MS_EXCEPTION(TypeError)
236         << "The inputs types of the outermost network support bool, int, float, None, tensor, "
237            "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
238            "and tuple or list containing only these types, and dict whose values are these types, but the "
239         << i << "th arg type is " << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
240     }
241   }
242 }
243 
GenerateArgumentsKey(const std::unordered_map<std::string,py::object> & args)244 py::object GenerateArgumentsKey(const std::unordered_map<std::string, py::object> &args) {
245   MS_LOG(DEBUG) << "GenerateArgumentsKey args size:" << args.size();
246   abstract::AbstractBasePtrList args_spec;
247 
248   for (const auto &arg : args) {
249     if (py::isinstance<py::module>(arg.second)) {
250       MS_LOG(EXCEPTION) << "GenerateArgumentsKey failed, argument input should not be py::module";
251     }
252     ValuePtr converted = nullptr;
253     if (!parse::ConvertData(arg.second, &converted)) {
254       MS_LOG(EXCEPTION) << "GenerateArgumentsKey convert arg failed";
255     }
256     args_spec.push_back(ArgsToAbstract(converted));
257   }
258 
259   uint64_t key;
260   auto iter = g_args_cache.find(args_spec);
261   if (iter == g_args_cache.end()) {
262     static uint64_t key_counter = 0;
263     key = key_counter;
264     ++key_counter;
265     g_args_cache[args_spec] = key;
266     MS_LOG(INFO) << "Generate a new compile key for new args, key: " << key;
267   } else {
268     key = iter->second;
269   }
270   return py::int_(key);
271 }
272 
// Check that `inputs` matches `input_signature` element-wise: tensor args
// must have exactly the declared shape and dtype. Returns false (with an
// error log) on any mismatch instead of throwing.
py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs) {
  MS_LOG(DEBUG) << "Verify args size:" << inputs.size();
  if (inputs.size() != input_signature.size()) {
    MS_LOG(ERROR) << "Signature size not equal to args size";
    return false;
  }

  size_t count = 0;
  for (auto arg_obj : inputs) {
    // Only tensor arguments carry shape/dtype constraints; others are skipped.
    if (py::isinstance<Tensor>(arg_obj)) {
      MS_LOG(DEBUG) << "Verify Tensor";
      auto m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
      if (m_tensor == nullptr) {
        MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
        return false;
      }
      auto sig = input_signature[count].cast<std::shared_ptr<MetaTensor>>();
      // The signature element cast can fail; guard before dereferencing
      // (the tensor side was checked, the signature side previously was not).
      if (sig == nullptr) {
        MS_LOG(ERROR) << "Verify Tensor error, signature ptr is null";
        return false;
      }
      ShapeVector sig_shape = sig->shape();
      TypePtr sig_type = sig->Dtype();

      ShapeVector tensor_shape = m_tensor->shape_c();
      if (tensor_shape != sig_shape) {
        MS_LOG(ERROR) << "Python input shape is incompatible with input_signature";
        return false;
      }

      if (*m_tensor->Dtype() != *sig_type) {
        MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature("
                      << sig_type->ToString() << ")";
        return false;
      }
    }
    count++;
  }

  return true;
}
310 
// Record a snapshot of `graph` (the FuncGraph after pipeline action `i`,
// named `action.first`) into the Running Data Recorder. The graph is cloned
// first so later pipeline stages cannot mutate the recorded copy.
void PipelineRDRProcess(const FuncGraphPtr &graph, const std::vector<ActionItem> &actions, const ActionItem &action,
                        size_t i) {
  MS_LOG(INFO) << "Recording FuncGraph in pipeline using RDR.";
  std::string name = GetBaseNameForIR(SizeToLong(i), action.first);
  if (graph != nullptr) {
    auto graph_clone = BasicClone(graph);
    if (graph_clone != nullptr) {
      DumpGraphParams dump_params = {false, static_cast<int>(kTopStack)};
      // NOTE(review): `i == actions.size()` only holds if the caller passes a
      // one-past-the-end index for the final dump — confirm callers' convention
      // (an in-range index would make this branch unreachable).
      if (i == actions.size()) {
        dump_params.dump_mode = static_cast<int>(kWholeStack);
      }
      (void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph_clone, dump_params, ".ir");
    } else {
      MS_LOG(WARNING) << "Clone FuncGraph failed in pipeline, no FuncGraph recording in RDR.";
    }
  } else {
    MS_LOG(WARNING) << "Pipeline Resource has no FuncGraph, no FuncGraph recording in RDR";
  }
  MS_LOG(INFO) << "Recording FuncGraph in pipeline end.";
}
331 
GraphExecutorPy()332 GraphExecutorPy::GraphExecutorPy() {}
333 
GetResource(const std::string & phase)334 ResourcePtr GraphExecutorPy::GetResource(const std::string &phase) {
335   MS_LOG(DEBUG) << "Phase size:" << info_.size();
336   if (info_.count(phase) == 0) {
337     return nullptr;
338   }
339   return info_[phase]->resource;
340 }
341 
GetFuncGraph(const std::string & phase)342 FuncGraphPtr GraphExecutorPy::GetFuncGraph(const std::string &phase) {
343   if (info_.count(phase) == 0) {
344     MS_LOG(EXCEPTION) << "No executor info. found for phase: " << phase;
345   }
346   return info_[phase]->func_graph;
347 }
348 
GetGradGraph(const std::string & phase)349 FuncGraphPtr GraphExecutorPy::GetGradGraph(const std::string &phase) {
350   if (phase.empty()) {
351     MS_LOG(EXCEPTION) << "The input phase is empty.";
352   }
353   if (info_.count(phase) == 0) {
354     MS_LOG(EXCEPTION) << "No phase in executor:" << phase;
355   }
356 
357   auto execute_info = info_[phase];
358   MS_EXCEPTION_IF_NULL(execute_info);
359   auto grad_graph = execute_info->grad_graph;
360   MS_EXCEPTION_IF_NULL(grad_graph);
361   return grad_graph;
362 }
363 
SetGradGraph(const FuncGraphPtr & grad_graph,const std::string & phase)364 void GraphExecutorPy::SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase) {
365   if (phase.empty()) {
366     MS_LOG(EXCEPTION) << "The input phase is empty.";
367   }
368   if (info_.count(phase) == 0) {
369     MS_LOG(EXCEPTION) << "No phase in executor: " << phase;
370   }
371 
372   auto execute_info = info_[phase];
373   MS_EXCEPTION_IF_NULL(execute_info);
374   if (execute_info->grad_graph != nullptr) {
375     MS_LOG(DEBUG) << "The grad graph has existed, phase is: " << phase;
376   }
377   MS_EXCEPTION_IF_NULL(grad_graph);
378   execute_info->grad_graph = grad_graph;
379 }
380 
GetVmEvalFunc(const std::string & phase)381 compile::VmEvalFuncPtr GraphExecutorPy::GetVmEvalFunc(const std::string &phase) {
382   ResourcePtr res = GetResource(phase);
383   MS_EXCEPTION_IF_NULL(res);
384   if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is<compile::VmEvalFuncPtr>()) {
385     return res->results()[kOutput].cast<compile::VmEvalFuncPtr>();
386   }
387   MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput;
388   return nullptr;
389 }
390 
HasCompiled(const std::string & phase) const391 bool GraphExecutorPy::HasCompiled(const std::string &phase) const {
392   if (info_.count(phase) == 0) {
393     return false;
394   }
395   return true;
396 }
397 
GetFuncGraphProto(const std::string & phase,const std::string & ir_type)398 py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) {
399   FuncGraphPtr fg_ptr = GetFuncGraph(phase);
400   if (fg_ptr == nullptr) {
401     for (auto &item : info_) {
402       MS_LOG(DEBUG) << "Phase key is: " << item.first;
403     }
404     MS_LOG(EXCEPTION) << "Can not find func graph " << phase;
405   }
406 
407   if (ir_type == IR_TYPE_ANF) {
408     std::string proto_str = GetFuncGraphProtoString(fg_ptr);
409     if (proto_str.empty()) {
410       MS_LOG(EXCEPTION) << "Export ANF format model failed.";
411     }
412     return proto_str;
413   }
414 
415   if (ir_type == IR_TYPE_ONNX) {
416     std::string proto_str = GetOnnxProtoString(fg_ptr);
417     if (proto_str.empty()) {
418       MS_LOG(EXCEPTION) << "Export ONNX format model failed.";
419     }
420     return proto_str;
421   }
422 
423   if (ir_type == IR_TYPE_MINDIR) {
424     std::string proto_str = GetBinaryProtoString(fg_ptr);
425     if (proto_str.empty()) {
426       MS_LOG(EXCEPTION) << "Export MINDIR format model failed.";
427     }
428     return proto_str;
429   }
430 
431   MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
432 }
433 
GetParameterLayout(const std::string & phase)434 py::dict GraphExecutorPy::GetParameterLayout(const std::string &phase) {
435   MS_LOG(DEBUG) << "GetParameterLayout!";
436   std::string layout_graph = phase + kStepParallelGraph;
437   auto graph = GetFuncGraph(layout_graph);
438   return mindspore::parallel::GetParameterLayout(graph);
439 }
440 
GetCNodeStrategy(const std::string & phase)441 py::dict GraphExecutorPy::GetCNodeStrategy(const std::string &phase) {
442   MS_LOG(DEBUG) << "GetCNodeStrategy!";
443   return stra_dict_[phase];
444 }
445 
GetParallelParameterNameList(const std::string & phase)446 py::list GraphExecutorPy::GetParallelParameterNameList(const std::string &phase) {
447   std::string param_graph = phase + kStepParallelGraph;
448   auto graph = GetFuncGraph(param_graph);
449   return mindspore::parallel::GetParallelParameterNameList(graph);
450 }
451 
SetCNodeStrategy(const std::string & name,const parallel::Strategys & strategy)452 void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
453   MS_LOG(DEBUG) << "SetCNodeStrategy!";
454   stra_dict_[phase_][py::str(name)] = strategy;
455 }
456 
GetNumOpsInfo(const std::string & phase)457 size_t GraphExecutorPy::GetNumOpsInfo(const std::string &phase) {
458   MS_LOG(DEBUG) << "GetNumOpsInfo!";
459   return phase_to_num_op_info_[phase];
460 }
461 
SetNumOpsInfo(size_t num_ops)462 void GraphExecutorPy::SetNumOpsInfo(size_t num_ops) {
463   MS_LOG(DEBUG) << "SetNumOpsInfo!";
464   phase_to_num_op_info_[phase_] = num_ops;
465 }
466 
GetAllreduceFusion(const std::string & phase)467 py::dict GraphExecutorPy::GetAllreduceFusion(const std::string &phase) {
468   MS_LOG(INFO) << "GetAllreduceFusion!";
469   auto graph = GetFuncGraph(phase);
470   return mindspore::parallel::GetAllreduceFusion(graph);
471 }
472 
473 // Not support multi thread, not support nested call too.
474 // Here using nested_called flg to avoid nested call.
void GraphExecutorPy::DelNetRes(const std::string &id) {
  // Re-entrancy guard: releasing resources can trigger code paths that call
  // DelNetRes again; the nested entry becomes a no-op.
  static bool nested_called = false;
  if (nested_called) {
    return;
  }
  nested_called = true;
#ifdef ENABLE_GE
  FinalizeBackend();
#else
  ConfigManager::GetInstance().ResetIterNum();
#endif
  if (executor_ != nullptr) {
    bool flag = false;
    // Iterate over a copy because entries are erased from info_ inside the loop.
    auto tmp_info = info_;
    for (auto &item : tmp_info) {
      // Every phase whose key contains `id` belongs to the network being deleted.
      if (item.first.find(id) != string::npos) {
        MS_LOG(DEBUG) << "Delete network res:" << item.first;
        item.second = nullptr;
        (void)info_.erase(item.first);
        flag = true;
      }
    }

    MS_LOG(DEBUG) << "Delete flag:" << flag;
#ifdef ENABLE_GE
    if (flag && info_.size() == 0) {
      // because Ge only support one Session exist at the same time ,so we delete the old one
      transform::DfGraphManager::GetInstance().DeleteGraphRunner();
      transform::DfGraphManager::GetInstance().EraseAnfGraph();
      transform::DfGraphManager::GetInstance().DeleteGeSession();
    }
#endif
  }
  nested_called = false;
}
510 
ClearRes()511 void GraphExecutorPy::ClearRes() {
512   MS_LOG(INFO) << "Clean executor resource!";
513   executor_ = nullptr;
514 }
515 
~GraphExecutorPy()516 GraphExecutorPy::~GraphExecutorPy() {
517   MS_LOG(INFO) << "Release Executor!";
518   ConfigManager::GetInstance().ResetConfig();
519 }
520 
// For one quantized operator (`root_node`), resolve the name of its weight
// parameter (`weight_node`) and the fake-quant primitive feeding the op's
// first input (plus that primitive's min-parameter name), and record the pair
// into `fake_quant_table` keyed by the weight name.
void GraphExecutorPy::GetWeightInfo(
  const CNodePtr &root_node, const AnfNodePtr &weight_node,
  std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) {
  MS_EXCEPTION_IF_NULL(root_node);
  MS_EXCEPTION_IF_NULL(fake_quant_table);
  std::string weight_name;
  auto x = root_node->input(1);
  MS_EXCEPTION_IF_NULL(x);
  // A weight wrapped in Load(...) keeps the actual Parameter at input(1).
  if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
    weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
  } else {
    auto para = weight_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(para);
    weight_name = para->name();
  }
  // find the fakequant from input
  int64_t count = 0;
  const int64_t max_depth = 5;
  CNodePtr cnode = nullptr;
  auto is_quant_cnode = [](const AnfNodePtr &node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerChannel);
  };
  // Walk up through input(1) at most max_depth hops looking for a fake-quant node.
  while (!is_quant_cnode(x)) {
    if (count >= max_depth) {
      break;
    }
    cnode = x->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() <= 1) {
      break;
    }
    x = cnode->input(1);
    count += 1;
  }
  // Reaching a bare Parameter (or Load) means the op input itself is not
  // fake-quantized; record that with a null primitive adapter.
  if (x->isa<Parameter>() || IsPrimitiveCNode(x, prim::kPrimLoad)) {
    (*fake_quant_table)[weight_name] = std::make_pair(nullptr, "input");
  }
  // get the fakequant parameter minq's name
  if (!is_quant_cnode(x)) {
    return;
  }
  cnode = x->cast<CNodePtr>();
  // Expect prim + 3 inputs on the fake-quant node; bail out otherwise.
  constexpr size_t expect_input_size = 4;
  if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != expect_input_size) {
    return;
  }
  const size_t fakequant_index = 2;
  auto fakequant_min_node = cnode->input(fakequant_index);
  if (!fakequant_min_node->isa<Parameter>() && !IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
    return;
  }
  std::string fakequant_min_node_name;
  if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
    fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
  } else {
    auto param = fakequant_min_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    fakequant_min_node_name = param->name();
  }
  auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
  MS_EXCEPTION_IF_NULL(quant_op_value);
  // Only Python-defined primitives carry an adapter worth exporting.
  if (!quant_op_value->isa<PrimitivePy>()) {
    return;
  }
  auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
  (*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name);
}
590 
// Scan the compiled graph of `phase` for quantized Conv2D / MatMul /
// DepthwiseConv2dNative operators and build a table mapping each weight
// parameter name to its fake-quant primitive adapter and min-parameter name.
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> GraphExecutorPy::FetchInfoForQuantExport(
  const std::string &phase) {
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase << ")!";
  std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> fake_quant_table;
  // Filter keeps only the op kinds we export (filter returns true to prune a node).
  auto filter = [](const AnfNodePtr &node) {
    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
             IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
  };
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
  auto is_quant_cnode = [](const AnfNodePtr &node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerChannel);
  };
  const size_t root_node_size = 3;
  const size_t weight_index = 2;
  for (const auto &node : nodes) {
    auto root_node = node->cast<CNodePtr>();
    // Only ops of the form (prim, input, weight) are considered.
    if (root_node == nullptr || root_node->size() != root_node_size) {
      continue;
    }
    auto weight = root_node->input(weight_index);
    if (!is_quant_cnode(weight)) {
      // The weight may be wrapped one level deep (input(1) of a cnode holds
      // the actual fake-quant node); unwrap once before giving up.
      auto tuple_node = weight->cast<CNodePtr>();
      if (tuple_node != nullptr) {
        auto fake_node = tuple_node->input(1);
        if (!is_quant_cnode(fake_node)) {
          continue;
        } else {
          weight = fake_node;
        }
      }
    }
    // get parameter weight's name
    auto cnode = weight->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto weight_node = cnode->input(weight_index);
    if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
      continue;
    }
    GetWeightInfo(root_node, weight_node, &fake_quant_table);
  }
  return fake_quant_table;
}
638 
// Store the compiled func graph of `phase` into the executor info map; under
// (semi-)auto parallel, additionally store the parameter-layout graph produced
// by the parallel pass under the key `phase + kStepParallelGraph` so that
// GetParameterLayout()/GetParallelParameterNameList() can retrieve it later.
void GraphExecutorPy::SaveCompiledGraph(const std::string &phase) {
  // save the graph to GraphExecutorPy
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();

  MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase << ")!";
  info_[phase]->func_graph = func_graph;
  if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) &&
      ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
    // The parallel pass left the layout graph in the resource results.
    func_graph = info_[phase]->resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();
    ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
    std::string layout_graph = phase + kStepParallelGraph;
    executor_info->func_graph = func_graph;
    info_[layout_graph] = executor_info;
  } else {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!";
  }
  MS_LOG(INFO) << "End save compiled func graph!";
}
661 
GetGeBackendPolicy() const662 void GraphExecutorPy::GetGeBackendPolicy() const {
663   auto ms_context = MsContext::GetInstance();
664   MS_EXCEPTION_IF_NULL(ms_context);
665   std::string backend = ms_context->backend_policy();
666   if (backend != "ge") {
667     MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!";
668   }
669 }
670 
// True when the phase string contains the "export.air" marker.
bool IsPhaseExportAir(const std::string &phase) {
  static const std::string kExportAirTag = "export.air";
  return phase.find(kExportAirTag) != std::string::npos;
}
675 
// True when the phase string contains the "train" marker.
bool IsPhaseTrain(const std::string &phase) {
  static const std::string kTrainTag = "train";
  return phase.find(kTrainTag) != std::string::npos;
}
680 
// True when the phase string contains the "graph_load_from_mindir" marker.
bool IsPhaseLoadFromMindIR(const std::string &phase) {
  static const std::string kMindIRTag = "graph_load_from_mindir";
  return phase.find(kMindIRTag) != std::string::npos;
}
685 
// Select the action pipeline (frontend passes + backend actions) for `phase`,
// based on the backend policy, the parameter-server / federated-learning role
// of this process, and whether a cached compiled graph was already loaded.
std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::string &phase, bool use_vm) {
  MS_EXCEPTION_IF_NULL(resource);
  bool is_air = IsPhaseExportAir(phase);

  std::string backend = MsContext::GetInstance()->backend_policy();

#if ((defined ENABLE_CPU) && (!defined _WIN32))
  // PS / FL roles run dedicated pipelines instead of the normal compile flow.
  const std::string &server_mode = ps::PSContext::instance()->server_mode();
  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
      ps::PSContext::instance()->is_server()) {
    return ServerPipeline();
  }
  if (ps::PSContext::instance()->is_server()) {
    resource->results()[kBackend] = compile::CreateBackend();
    return PServerPipeline();
  }
  if (ps::PSContext::instance()->is_scheduler()) {
    return PSchedulerPipeline();
  }
#endif

  if (use_vm && backend != "ge" && !is_air) {
    compile::SetMindRTEnable();
    // Create backend.
    auto backend_ptr = compile::CreateBackend();
#ifdef ENABLE_DEBUGGER
    // Connect session to debugger
    backend_ptr->SetDebugger();
#endif
    resource->results()[kBackend] = backend_ptr;
    // If the 'use_frontend_compile_cache' context has been set true and the cache is read successfully,
    // do the backend actions only.
    if (IsPhaseTrain(phase) && MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE) &&
        resource->func_graph() != nullptr) {
      return BackendPipeline();
    }
    if (IsPhaseLoadFromMindIR(phase)) {
      return MindIRPipeline();
    }
    return VmPipeline();
  }
  return GePipeline();
}
729 
// Compile `source_obj` (a python function or Cell) for the given phase.
// Validates the inputs, builds the action pipeline, converts the call
// arguments to abstract specs, runs the pipeline, and caches the result in
// `info_[phase]`. Returns false on invalid `phase`/`source_obj`; compile
// failures inside the pipeline raise exceptions instead.
bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj,
                                   bool use_vm, const std::string &queue_name) {
  // Check if the phase is valid.
  if ((!py::isinstance<py::str>(phase_obj))) {
    MS_LOG(ERROR) << "The `phase` must be string.";
    return false;
  }
  // Check if the function or net is valid.
  if (py::isinstance<py::none>(source_obj)) {
    MS_LOG(ERROR) << "The source object to compile should not be None.";
    return false;
  }
  // Check if the args of function or net is valid.
  CheckArgsValid(args);

  auto phase = py::cast<std::string>(phase_obj);
  MS_LOG(INFO) << "Start compiling, phase: " << phase << ".";
  MS_LOG(DEBUG) << "Compiling source: {" << py::str(source_obj)
                << "}\n\n Args: " << py::str(const_cast<py::tuple &>(args));

#ifdef ENABLE_GE
  GetGeBackendPolicy();
#endif
  ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
  ResourcePtr resource = std::make_shared<Resource>(source_obj);

  // When the compile-cache context flag is set, try to restore a previously
  // cached front-end graph into `resource` (timed when profiling is enabled).
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE)) {
#ifdef ENABLE_PROFILE
    double t1 = GetTime();
#endif
    GetCachedFuncGraph(resource, queue_name);
#ifdef ENABLE_PROFILE
    double t2 = GetTime();
    MsProfile::StatTime("LoadCachedFuncGraph", t2 - t1);
#endif
  }

  phase_ = phase;
  // Build the action list for this phase; FilterActions truncates the list
  // after 'validate' for the 'export*' phases.
  auto actions = GetPipeline(resource, phase, use_vm);
  std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(actions, phase));

  // Get the parameters items and add the value to args_spec.
  abstract::AbstractBasePtrList args_spec;
  std::size_t size = args.size();
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
    }
    args_spec.push_back(ArgsToAbstract(converted));
  }
  resource->set_args_spec(args_spec);
  executor_info->arg_list_size = size;
  executor_info->resource = resource;
  // Register the executor info before running the pipeline so later stages can
  // look this phase up.
  info_[phase] = executor_info;
  pip->Run(phase);

  // Save the compiled graph to MsPipeLine.
  SaveCompiledGraph(phase);

  // Release per-compile caches and optimizer state; resource->Clean() keeps
  // only what later execution needs.
  opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes();
  abstract::AnalysisContext::ClearContext();
  // Reclaim all resource used by optimizer.
  ReclaimOptimizer();
  resource->Clean();

  MS_LOG(INFO) << "Finish compiling.";
  return true;
}
800 
FilterActions(const std::vector<ActionItem> & actions,const std::string & phase)801 std::vector<ActionItem> GraphExecutorPy::FilterActions(const std::vector<ActionItem> &actions,
802                                                        const std::string &phase) {
803   // filter action after validate when 'export'.
804   if (GetPhasePrefix(phase).rfind("export", 0) == std::string::npos) {
805     return actions;
806   }
807   MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'";
808   std::vector<ActionItem> filtered_actions;
809   for (const auto &item : actions) {
810     filtered_actions.emplace_back(item);
811     if (item.first == "validate") {
812       break;
813     }
814   }
815   return filtered_actions;
816 }
817 
ReleaseResource(const py::object & phase)818 void GraphExecutorPy::ReleaseResource(const py::object &phase) {
819   ResourcePtr res = GetResource(py::cast<std::string>(phase));
820   if (res != nullptr) {
821     res->Clean();
822   }
823   // Reclaim all resource used by optimizer;
824   ReclaimOptimizer();
825 }
826 
// Exception-safe wrapper around CompileInner. On any failure the per-phase
// resources are released, then the exception is re-thrown with its concrete
// type preserved so the python caller sees the matching python exception
// class (TypeError, ValueError, ...) rather than a generic RuntimeError.
bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase, bool use_vm,
                              const std::string &queue_name) {
  bool ret_value = false;
  try {
    ret_value = CompileInner(source_obj, args, phase, use_vm, queue_name);
  } catch (const py::error_already_set &ex) {
    // A pending static-analysis exception already carries its own trace info.
    if (!StaticAnalysisException::Instance().HasException()) {
      // print function call stack info before release
      std::string exception_info = GetCompileExceptionInfo();
      if (!exception_info.empty()) {
        MS_LOG(ERROR) << exception_info;
      }
    }
    ReleaseResource(phase);

    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    ReleaseResource(phase);
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    ReleaseResource(phase);
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    ReleaseResource(phase);
    throw py::index_error(ex);
  } catch (const py::key_error &ex) {
    ReleaseResource(phase);
    throw py::key_error(ex);
  } catch (const py::attribute_error &ex) {
    ReleaseResource(phase);
    throw py::attribute_error(ex);
  } catch (const py::name_error &ex) {
    ReleaseResource(phase);
    throw py::name_error(ex);
  } catch (const std::exception &ex) {
    ReleaseResource(phase);
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    // Unknown (non-std) exception: report its RTTI name before aborting the
    // compile with a MindSpore exception.
    ReleaseResource(phase);
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
  return ret_value;
}
873 
CacheValidateFuncGraph(const std::string & phase,const ResourcePtr & resource)874 void CacheValidateFuncGraph(const std::string &phase, const ResourcePtr &resource) {
875   if (IsPhaseTrain(phase) && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_COMPILE_CACHE)) {
876 #ifdef ENABLE_PROFILE
877     double t1 = GetTime();
878 #endif
879     CacheFuncGraph(resource);
880 #ifdef ENABLE_PROFILE
881     double t2 = GetTime();
882     MsProfile::StatTime("SaveCacheFuncGraph", t2 - t1);
883 #endif
884   }
885 }
886 
// Execute every compiled pipeline action in order for `phase`.
// Per-step side effects: the loop count is configured after 'task_emit', the
// validated graph may be written to the compile cache after 'validate', and
// when IR dumping is enabled each step emits .dot/.ir/.dat files.
void Pipeline::Run(const std::string &phase) {
  MS_LOG(INFO) << "Pipeline run";
  MS_EXCEPTION_IF_NULL(resource_);
  FuncGraphPtr user_graph = nullptr;

  // Profile the whole pipeline; each action is additionally timed as its own
  // profiler step below.
  WITH(MsProfile::GetProfile())[&user_graph, &phase, this]() {
    size_t i = 0;
    for (auto &action : actions_) {
#ifdef ENABLE_TIMELINE
      DumpTime &dump_time = DumpTime::GetInstance();
      dump_time.Record(action.first, GetTime(), true);
#endif
      bool result = true;
      WITH(MsProfile::GetProfile()->Step(action.first))[&result, &action, this]() {
        MS_LOG(DEBUG) << "Action " << action.first << " start ...";
        result = action.second(resource_);
        MS_LOG(DEBUG) << "Action " << action.first << " end.";
      };
      if (action.first == "task_emit") {
        SetLoopCount(resource_);
      } else if (action.first == "validate") {
        CacheValidateFuncGraph(phase, resource_);
      }
      if (!result) {
        MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
      }

      FuncGraphPtr graph = resource_->func_graph();
#ifdef ENABLE_DUMP_IR
      if (mindspore::RecorderManager::Instance().RdrEnable()) {
        PipelineRDRProcess(graph, actions_, action, i);
      }
      if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && graph != nullptr) {
        // Remember the latest graph so the user-level digraph can be drawn
        // once after the loop finishes.
        user_graph = graph;
        std::string base_name = GetBaseNameForIR(SizeToLong(i), action.first);

        // generate IR file in dot format, which can be converted to svg file using graphviz dot command
        draw::Draw(base_name + ".dot", graph);
        // generate IR file in human readable format
        if (i == actions_.size() - 1) {
          // the final stage dumps with the whole call stack attached
          DumpIR(base_name + ".ir", graph, false, kWholeStack);
        } else {
          DumpIR(base_name + ".ir", graph, false, kTopStack);
        }
        // generate IR file in a heavily commented format, which can also be reloaded
        ExportIR(base_name + ".dat", graph);
      }
#endif
      i++;
#ifdef ENABLE_TIMELINE
      dump_time.Record(action.first, GetTime(), false);
#endif
    }
  };
#ifdef ENABLE_PROFILE
  MsProfile::Print();
  MsProfile::Reset();
#endif

#ifdef ENABLE_DUMP_IR
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) {
    draw::DrawUserFuncGraph("ModelDigraph.dot", user_graph);
  }
#endif
  MS_LOG(INFO) << "End";
}
953 
ProcessVmArgInner(const py::tuple & args,const ResourcePtr & res,VectorRef * const arg_list)954 void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
955   MS_EXCEPTION_IF_NULL(arg_list);
956   std::size_t size = args.size();
957   bool arg_list_inited = !arg_list->empty();
958   for (std::size_t i = 0; i < size; i++) {
959     py::object arg = args[i];
960     auto ms_context = MsContext::GetInstance();
961     if (ms_context->backend_policy() == kMsConvert && py::isinstance<py::array>(arg)) {
962       MS_LOG(EXCEPTION) << "The " << i << "th arg is numpy array, not tensor.";
963     }
964     ValuePtr converted = nullptr;
965     bool succ = parse::ConvertData(arg, &converted);
966     if (!succ) {
967       MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
968     }
969     if (!arg_list_inited) {
970       arg_list->push_back(converted);
971       continue;
972     }
973     if (i >= arg_list->size()) {
974       MS_LOG(EXCEPTION) << "i:" << i << " output of range:" << arg_list->size();
975     }
976     (*arg_list)[i] = converted;
977   }
978 
979   MS_EXCEPTION_IF_NULL(res);
980   auto graph = res->func_graph();
981   MS_EXCEPTION_IF_NULL(graph);
982   std::vector<AnfNodePtr> graph_params = graph->parameters();
983   std::size_t graph_params_size = graph_params.size();
984   if ((*arg_list).size() != graph_params_size) {
985     // maybe some default parameter
986     for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
987       MS_EXCEPTION_IF_NULL(graph_params[i]);
988       auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
989       MS_EXCEPTION_IF_NULL(param_ptr);
990       if (!param_ptr->has_default()) {
991         MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
992       }
993       if (!param_ptr->default_param()->isa<Tensor>()) {
994         MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString()
995                           << "] is not initialized, need to call `.init_data()`";
996       }
997       arg_list->push_back(param_ptr->default_param());
998     }
999   }
1000 }
1001 
ProcessVmArg(const py::tuple & args,const std::string & phase,VectorRef * const arg_list)1002 void GraphExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) {
1003   ProcessVmArgInner(args, GetResource(phase), arg_list);
1004 }
1005 
#ifdef ENABLE_DEBUGGER
// When the debugger has requested termination, release all pipeline resources
// and exit the process; the exit code reflects `exit_success_`. No-op
// otherwise.
void GraphExecutorPy::TerminateDebugger() {
  if (!debugger_terminate_) {
    return;
  }
  MS_LOG(INFO) << "Terminate debugger and clear resources!";
  ClearResAtexit();
  exit(exit_success_ ? 0 : 1);
}
#endif
1019 
// Execute the graph compiled for `phase` with `args` and return the result
// converted back into a python object.
py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) {
  // Mindspore debugger notify main thread to exit after one step, and will not run next step
#ifdef ENABLE_DEBUGGER
  TerminateDebugger();
#endif
  std::size_t size = args.size();
  if (!py::isinstance<py::str>(phase_obj)) {
    MS_LOG(EXCEPTION) << "Run failed, phase input is not a str";
  }
  auto phase = py::cast<std::string>(phase_obj);
  std::string backend = MsContext::GetInstance()->backend_policy();
#ifdef ENABLE_GE
  if (backend == "ge") {
    return ExecDFGraph(info_, args, phase);
  }
#else
  auto ret_val = std::make_shared<py::object>();
  // Fast path: a graph whose output is a constant or a parameter can be
  // answered without running the VM at all.
  if (info_.count(phase) != 0 && info_[phase]->func_graph != nullptr) {
    if (IsGraphOutputValueNodeOrParameter(info_[phase]->func_graph->output(), args, ret_val)) {
      // Check the input arg must be Tensor when backend is "ms".
      if (MsContext::GetInstance()->backend_policy() == kMsConvert) {
        for (std::size_t i = 0; i < size; i++) {
          ValuePtr converted = nullptr;
          if (!parse::ConvertData(args[i], &converted)) {
            MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
          }
        }
      }
      return *ret_val;
    }
  }
  if (backend == "ge") {
    // Virtual output constructed for test cases.
    if (!args.empty()) {
      return args[0];
    }
    return args;
  }
#endif
  auto iter = info_.find(phase);
  if (iter == info_.end()) {
    MS_LOG(EXCEPTION) << "No executor info. found for phase: " << phase;
  }
  auto &execute_info = iter->second;
  MS_EXCEPTION_IF_NULL(execute_info);
  // Extra arguments beyond the compiled signature are only warned about, not
  // rejected.
  if (size > execute_info->arg_list_size) {
    MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << execute_info->arg_list_size;
  }
  ProcessVmArg(args, phase, &execute_info->arg_list);
  // Start to run phase.
  compile::VmEvalFuncPtr run = GetVmEvalFunc(phase);
  if (run == nullptr) {
    MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase;
  }
  // Set loopsink size for each phase.
  bool vm_loop_flag = info_[phase]->resource->vm_loop_flag();
  int64_t loop_size = info_[phase]->resource->loop_size();
  int64_t vm_loop = 1;
  if (vm_loop_flag) {
    vm_loop = loop_size;
  } else {
    // Set the loop size in config if graphs nums is 1(is_loop_sin=True), then there will be a loop embrace
    // 'Execute(graph)' in GPUSession.
    ConfigManager::GetInstance().set_gpu_loopsink_size(loop_size);
  }
  MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
  py::object ret;
  MS_LOG(DEBUG) << "Eval run" << backend;
  // With loop-sinking enabled the eval function runs `vm_loop` times; the
  // result of the last iteration is what gets returned.
  for (int64_t i = 0; i < vm_loop; i++) {
    BaseRef value = (*run)(execute_info->arg_list);
    ret = BaseRefToPyData(value);
  }
  MS_LOG(DEBUG) << "Run end";
  return ret;
}
1095 
// Build the GE data-flow graph for `phase` from the given init parameters.
// Only meaningful in GE/D builds; returns nullptr otherwise.
FuncGraphPtr GraphExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase,
                                         const py::object &broadcast_params) {
#if ((defined ENABLE_GE) || (defined ENABLE_D))
  return BuildDFGraph(info_, init_params, phase, broadcast_params);
#else
  return nullptr;
#endif
}
1104 
UpdataParamNodeDefaultInput(const std::string & phase,const std::unordered_map<std::string,tensor::TensorPtr> & params_value)1105 void GraphExecutorPy::UpdataParamNodeDefaultInput(
1106   const std::string &phase, const std::unordered_map<std::string, tensor::TensorPtr> &params_value) {
1107   FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
1108   MS_EXCEPTION_IF_NULL(func_graph);
1109   MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase
1110                 << ")!";
1111   auto &params = func_graph->parameters();
1112   for (const auto &param : params) {
1113     MS_EXCEPTION_IF_NULL(param);
1114     auto param_cast = param->cast<ParameterPtr>();
1115     MS_EXCEPTION_IF_NULL(param_cast);
1116     auto iter = params_value.find(param_cast->name());
1117     if (iter != params_value.end()) {
1118       param_cast->set_default_param(iter->second);
1119     }
1120   }
1121 }
1122 
// Run the GE parameter-initialization graph for `phase`; a no-op in non-GE
// builds.
void GraphExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) const {
#ifdef ENABLE_GE
  RunGEInitGraph(init_params, phase);
#endif
}
1128 
PyExePath(const py::object & py_exe_path)1129 void GraphExecutorPy::PyExePath(const py::object &py_exe_path) {
1130   if (!py::isinstance<py::str>(py_exe_path)) {
1131     MS_LOG(EXCEPTION) << "Failed, py_exe_path input is not a str";
1132   }
1133   auto py_exe_path_s = py::cast<std::string>(py_exe_path);
1134   auto ms_context = MsContext::GetInstance();
1135   ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
1136 }
1137 
KernelBuildServerDir(const py::object & kernel_build_server_dir)1138 void GraphExecutorPy::KernelBuildServerDir(const py::object &kernel_build_server_dir) {
1139   if (!py::isinstance<py::str>(kernel_build_server_dir)) {
1140     MS_LOG(EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
1141   }
1142   auto kernel_build_server_dir_s = py::cast<std::string>(kernel_build_server_dir);
1143   auto ms_context = MsContext::GetInstance();
1144   ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
1145 }
1146 
InitExecDataset(const std::string & queue_name,int64_t iter_num,int64_t batch_size,const std::vector<TypePtr> & types,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int64_t> & input_indexes,const std::string & phase,bool need_run)1147 bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
1148                      const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
1149                      const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run) {
1150   std::string name = MsContext::GetInstance()->backend_policy();
1151 #ifndef NO_DLIB
1152   auto ms_context = MsContext::GetInstance();
1153   MS_EXCEPTION_IF_NULL(ms_context);
1154   if (!context::IsTsdOpened(ms_context) || !context::IsGeInited(ms_context)) {
1155     InitPipeline();
1156   }
1157 #endif
1158   if (iter_num == -1) {
1159     iter_num = INT32_MAX;
1160   }
1161   if (name == kMsConvert || name == kMsVm) {
1162     return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
1163   }
1164 #ifdef ENABLE_GE
1165   return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase);
1166 #else
1167   std::string backend = MsContext::GetInstance()->backend_policy();
1168   if (backend == "ge") {
1169     return true;
1170   }
1171 #endif
1172   return false;
1173 }
1174 
InitExecDatasetVm(const std::string & queue_name,int64_t size,int64_t batch_size,const std::vector<TypePtr> & types,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int64_t> & input_indexes,bool need_run)1175 bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
1176                        const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
1177                        const std::vector<int64_t> &input_indexes, bool need_run) {
1178 #if ((defined ENABLE_CPU) && (!defined _WIN32))
1179   if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
1180     return true;
1181   }
1182 #endif
1183   MS_LOG(INFO) << "Start InitDataSet Entry";
1184   mindspore::parse::python_adapter::set_python_env_flag(true);
1185   ShapeVector int_input_indexes;
1186   (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
1187                        [](int64_t item) { return static_cast<int64_t>(item); });
1188   std::vector<ShapeVector> int_shapes;
1189   (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes),
1190                        [](const std::vector<int64_t> &item) {
1191                          ShapeVector vector_item;
1192                          (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item),
1193                                               [](int64_t inner_item) { return static_cast<int64_t>(inner_item); });
1194                          return vector_item;
1195                        });
1196   auto p_init = std::make_shared<Primitive>("InitDataSetQueue");
1197   p_init->set_attr("queue_name", MakeValue(queue_name));
1198   p_init->set_attr("size", MakeValue(static_cast<int64_t>(size)));
1199   p_init->set_attr("batch_size", MakeValue(static_cast<int64_t>(batch_size)));
1200   p_init->set_attr("types", MakeValue(types));
1201   p_init->set_attr("shapes", MakeValue(int_shapes));
1202   p_init->set_attr("input_indexes", MakeValue(int_input_indexes));
1203 
1204   const std::vector<std::string> empty_str_list;
1205   p_init->set_attr("input_names", MakeValue(empty_str_list));
1206   p_init->set_attr("output_names", MakeValue(empty_str_list));
1207 
1208   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
1209   auto app_init = std::make_shared<CNode>(AnfNodePtrList{NewValueNode(p_init)}, func_graph);
1210   func_graph->set_output(app_init);
1211   auto manager = MakeManager();
1212   manager->AddFuncGraph(func_graph);
1213 
1214   // AbstractNone indicates there is no output for this apply node.
1215   auto abstract_none = std::make_shared<abstract::AbstractNone>();
1216   app_init->set_abstract(abstract_none);
1217   // Before the graph compiling, need reset the iter num.
1218   ConfigManager::GetInstance().ResetIterNum();
1219 #ifdef ENABLE_DUMP_IR
1220   mindspore::RDR::ResetRecorder();
1221 #endif
1222 
1223   compile::SetMindRTEnable();
1224   auto backend = compile::CreateBackend();
1225   MS_EXCEPTION_IF_NULL(backend);
1226   // The data set graph compiling and running of mindRT.
1227   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
1228     const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
1229     MS_EXCEPTION_IF_NULL(mindrt_backend);
1230     auto &actor_info = mindrt_backend->CompileGraphs(func_graph);
1231     VectorRef args;
1232     if (need_run) {
1233       VectorRef outputs;
1234       mindrt_backend->RunGraph(actor_info, args, &outputs);
1235     }
1236     ConfigManager::GetInstance().set_iter_num(size);
1237     return true;
1238   }
1239 
1240   auto convert_fn = backend->convert_fn();
1241   MS_EXCEPTION_IF_NULL(convert_fn);
1242   // Convert CNodeList to LinConvertResult.
1243   auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
1244   auto runner = convert_fn(segment, "");
1245   ConfigManager::GetInstance().set_iter_num(size);
1246   // PS cache does not support loop sink.
1247 #if ((defined ENABLE_CPU) && (!defined _WIN32))
1248   if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
1249     ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
1250     ConfigManager::GetInstance().set_iter_num(1);
1251   }
1252 #endif
1253 
1254   if (!(*runner.run)) {
1255     // empty function
1256     MS_LOG(EXCEPTION) << "Backend " << backend->name() << " unsupported tdt dataset.";
1257   }
1258 
1259   // launch init dataset runner without inputs and outputs
1260   VectorRef args;
1261   auto fn = runner.run;
1262   if (need_run) {
1263     (void)(*fn)(args);
1264   }
1265   MS_LOG(DEBUG) << "InitDataSetVm End.";
1266   return true;
1267 }  // namespace pipeline
1268 
ResetOpId()1269 void ResetOpId() { mindspore::id_generator::reset_id(); }
1270 
// Initialize HCCL collective communication for the current device.
// In GE builds this delegates to the full pipeline init; otherwise it
// optionally runs MPI collective init (non-task-sink graph mode), enables the
// HCCL context flag, and initializes the Ascend kernel runtime after opening
// tsd.
void InitHccl() {
#ifdef ENABLE_GE
  (void)InitPipeline();
#else
  mindspore::parse::python_adapter::set_python_env_flag(true);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
#if ENABLE_D
  // Task sink is disabled when single-op execution or the memory scheduler is
  // requested via environment variables.
  bool task_sink = true;
  auto single_op = common::GetEnv(kGraphOpRun);
  auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
  if (single_op == "1" || enable_mem_scheduler == "1") {
    task_sink = false;
  }
  auto mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (!task_sink && mode == kGraphMode) {
    MS_LOG(INFO) << "mpi collective init.";
    if (!HcclCollectiveGroup::instance().InitCollective()) {
      MS_LOG(EXCEPTION) << "Mpi init failed, please check if mpirun is used correctly.";
    }
    // The collective layer assigns this process its device id.
    device_id = IntToUint(HcclCollectiveGroup::instance().GetDeviceId());
    ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
    ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
  }
#endif
  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
  if (ms_context->backend_policy() == "ms" &&
      ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
    MS_EXCEPTION_IF_NULL(runtime_instance);
#ifndef ENABLE_SECURITY
    runtime_instance->PreInit();
#endif
    (void)context::OpenTsd(ms_context);
    if (!runtime_instance->Init()) {
      MS_LOG(EXCEPTION) << "Runtime init failed.";
    }
  } else {
    (void)context::OpenTsd(ms_context);
  }
#endif
#if (defined ENABLE_D)
#ifndef ENABLE_SECURITY
  // Profiling has to know HCCL was enabled before profiling itself starts.
  if (!ProfilingManager::GetInstance().IsProfiling()) {
    ProfilingManager::GetInstance().SetHcclEnabledBefProfilingEnabled();
  }
#endif
#endif
}
1322 
// Tear down communication-related runtime state. GE builds finalize the whole
// backend; other builds clear the session executors and kernel runtimes.
void FinalizeHccl() {
#ifdef ENABLE_GE
  (void)FinalizeBackend();
#else
  session::ExecutorManager::Instance().Clear();
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
#endif
}
1331 
GetHcclRankId()1332 uint32_t GetHcclRankId() {
1333   uint32_t rank_id = 0;
1334   bool ret = CommManager::GetInstance().GetRankID("", &rank_id);
1335   if (!ret) {
1336     MS_LOG(ERROR) << "Get rank id failed, return rank id " << rank_id << " as default.";
1337   }
1338   return rank_id;
1339 }
1340 
GetHcclRankSize()1341 uint32_t GetHcclRankSize() {
1342   uint32_t rank_size = 0;
1343   bool ret = CommManager::GetInstance().GetRankSize("", &rank_size);
1344   if (!ret) {
1345     MS_LOG(ERROR) << "Get rank size failed, return rank size " << rank_size << " as default.";
1346   }
1347   return rank_size;
1348 }
1349 
// Export the compiled graph of `phase` to `file_name` in AIR format.
// Only available in GE/D builds; raises ValueError otherwise. The unnamed
// middle parameter is unused (kept for interface compatibility).
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) {
#if ((defined ENABLE_GE) || (defined ENABLE_D))
  ExportDFGraph(file_name, phase);
#else
  MS_EXCEPTION(ValueError) << "Only support export file in 'AIR' format with Ascend backend.";
#endif
}
1357 
LoadMindIR(const std::string & file_name,char * dec_key,const size_t key_len,const std::string & dec_mode)1358 FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len,
1359                         const std::string &dec_mode) {
1360   auto func_graph =
1361     mindspore::LoadMindIR(file_name, false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode);
1362 #ifdef ENABLE_DUMP_IR
1363   auto context_ptr = MsContext::GetInstance();
1364   MS_EXCEPTION_IF_NULL(context_ptr);
1365   bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
1366   if (save_graphs) {
1367     DumpIR("load.ir", func_graph);
1368   }
1369 #endif
1370   return func_graph;
1371 }
1372 
ReleaseGeTsd()1373 void ReleaseGeTsd() {
1374   auto context_ptr = MsContext::GetInstance();
1375   if (context_ptr != nullptr) {
1376     (void)context::FinalizeGe(context_ptr, true);
1377     (void)context::CloseTsd(context_ptr, true);
1378   }
1379 }
1380 
#ifndef ENABLE_SECURITY
// Start Ascend profiling for the "ms"+Ascend configuration. Must run before
// tsd is opened (see InitPipeline); a no-op when profiling is disabled or in
// non-D builds.
void StartUpProfiling() {
#ifdef ENABLE_D
  if (!ProfilingManager::GetInstance().IsProfiling()) {
    return;
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  MS_LOG(INFO) << "Startup profiling";
  // Start up profiling before OpenTsd
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (ms_context->backend_policy() == "ms" &&
      ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
    MS_EXCEPTION_IF_NULL(runtime_instance);
    runtime_instance->PreInit();
  }
#endif
}
#endif
1404 
InitPipeline()1405 void InitPipeline() {
1406   // set python env flag
1407   mindspore::parse::python_adapter::set_python_env_flag(true);
1408 #ifndef ENABLE_SECURITY
1409   // Startup profiling before open tsd
1410   StartUpProfiling();
1411 #endif
1412   // open tsd before ge initialize
1413   auto ms_context = MsContext::GetInstance();
1414   MS_EXCEPTION_IF_NULL(ms_context);
1415   if (!context::OpenTsd(ms_context)) {
1416     MS_LOG(EXCEPTION) << "Open tsd failed";
1417   }
1418   (void)context::InitGe(ms_context);
1419 }
1420 
FinalizeBackend()1421 void FinalizeBackend() {
1422   auto context_ptr = MsContext::GetInstance();
1423   MS_EXCEPTION_IF_NULL(context_ptr);
1424   (void)context::FinalizeGe(context_ptr);
1425   (void)context::CloseTsd(context_ptr);
1426 }
1427 
// Release every pipeline-held global resource at interpreter exit.
// NOTE(review): the teardown order below looks deliberate (workers before
// runtimes, runtimes before python-side caches) — preserve it when editing.
void ClearResAtexit() {
  MS_LOG(DEBUG) << "Pipeline clear all resource";
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  // Parameter-server mode: finalize this worker (FL or classic PS) first.
  if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      ps::ps_cache_instance.Finalize();
    }
    MS_LOG(INFO) << "Start finalizing worker.";
    const std::string &server_mode = ps::PSContext::instance()->server_mode();
    if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
      fl::worker::FLWorker::GetInstance().Finalize();
    } else {
      ps::Worker::GetInstance().Finalize();
    }
  }
#endif
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::ResetRecorder();
#endif
  // Device/session level resources.
  session::ExecutorManager::Instance().Clear();
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
  runtime::GraphScheduler::GetInstance().Clear();
  device::DeviceContextManager::GetInstance().ClearDeviceContexts();
  // Autodiff caches.
  ad::g_k_prims.clear();
  ad::ClearKPynativeCellStaticRes();
  ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();

  // Front-end compile caches and executors.
  abstract::ClearPrimEvaluatorMap();
  pipeline::GetMethodMap().clear();
  pipeline::GetAttrMap().clear();
  pipeline::GraphExecutorPy::ClearRes();
  pipeline::ReclaimOptimizer();
  pynative::PynativeExecutor::GetInstance()->ClearRes();
  opt::python_pass::PyPassManager::GetInstance()->ClearRes();
#ifdef ENABLE_GE
  transform::DfGraphManager::GetInstance().ClearGraph();
  transform::OpAdapterMap::get().clear();
#else
  ConfigManager::GetInstance().ResetIterNum();
#endif
  ReleaseGeTsd();
  parse::python_adapter::ResetPythonScope();
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  abstract::AnalysisContext::ClearContext();
  abstract::AnalysisSchedule::GetInstance().Stop();
#ifdef ENABLE_DEBUGGER
  Debugger::GetInstance()->Reset();
#endif
  g_args_cache.clear();
  // clean static variable to prevent from crash. As static variable is released after
  // Python threads is released.
  parse::data_converter::ClearObjectCache();
  parse::Parser::CleanParserResource();
  parse::CleanDataClassToClassMap();
  trace::ClearTraceStack();
}
1484 
PyEncrypt(char * plain_data,size_t plain_len,char * key,size_t key_len,const std::string & enc_mode)1485 py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode) {
1486   size_t encrypt_len;
1487   auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
1488                                          reinterpret_cast<Byte *>(key), key_len, enc_mode);
1489   if (encrypt_data == nullptr) {
1490     MS_EXCEPTION(ValueError) << "Encrypt failed";
1491   }
1492   auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
1493   return py_encrypt_data;
1494 }
1495 
PyDecrypt(const std::string & encrypt_data_path,char * key,size_t key_len,const std::string & dec_mode)1496 py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode) {
1497   size_t decrypt_len;
1498   auto decrypt_data =
1499     mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
1500   if (decrypt_data == nullptr) {
1501     MS_LOG(ERROR) << "Decrypt failed";
1502     return py::none();
1503   }
1504   auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
1505   return py_decrypt_data;
1506 }
1507 
PyIsCipherFile(const std::string & file_path)1508 bool PyIsCipherFile(const std::string &file_path) { return mindspore::IsCipherFile(file_path); }
1509 }  // namespace pipeline
1510 }  // namespace mindspore
1511