/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/pipeline.h"

#include <memory>
#include <sstream>
#include <map>
#include <unordered_map>
#include <cstdlib>
#include <algorithm>
#include <iomanip>

#include "ir/param_info.h"
#include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "debug/anf_ir_utils.h"
#include "debug/common.h"
#include "utils/config_manager.h"
#include "utils/convert_utils.h"
#include "utils/convert_utils_py.h"
#include "utils/context/context_extends.h"
#include "vm/segment_runner.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h"
#include "debug/trace.h"
#include "debug/draw.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "pybind_api/pybind_patch.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/crypto.h"
#include "utils/comm_manager.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/constants.h"
#include "ps/util.h"
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "fl/server/server.h"
#include "fl/worker/fl_worker.h"
#endif

#if ((defined ENABLE_GE) || (defined ENABLE_D))
#include "pipeline/jit/pipeline_ge.h"
#include "transform/graph_ir/convert.h"
#include "transform/graph_ir/df_graph_manager.h"
#include "transform/graph_ir/op_adapter_map.h"
#include "runtime/device/ascend/profiling/profiling_manager.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
#include "debug/rdr/recorder_manager.h"
#endif

namespace mindspore {
// namespace to support intermediate representation definition
namespace pipeline {
using Tensor = mindspore::tensor::Tensor;
using MetaTensor = mindspore::tensor::MetaTensor;
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;

#ifdef ENABLE_D
#ifndef ENABLE_SECURITY
using mindspore::device::ascend::ProfilingManager;
#endif
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
#endif

const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
const char IR_TYPE_MINDIR[] = "mind_ir";

GraphExecutorPyPtr GraphExecutorPy::executor_ = nullptr;
std::mutex GraphExecutorPy::instance_lock_;
#ifdef ENABLE_DEBUGGER
bool GraphExecutorPy::debugger_terminate_ = false;
bool GraphExecutorPy::exit_success_ = false;
#endif

std::unordered_map<abstract::AbstractBasePtrList, uint64_t, abstract::AbstractBasePtrListHasher,
                   abstract::AbstractBasePtrListEqual>
  g_args_cache;

namespace {
constexpr char kCompileCacheFilePath[] = "compile_cache.mindir";
#ifdef ENABLE_DUMP_IR
std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
  std::ostringstream oss;
  int spaces = 2;
  oss << std::setfill('0') << std::setw(spaces) << stage_idx << "_" << action_name;
  return oss.str();
}
#endif

AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  bool broaden = value->isa<MetaTensor>() ||
                 (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());

  return abstract::FromValue(value, broaden);
}

bool CheckArgValid(const py::handle &arg) {
  if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
    auto vector_arg = py::cast<py::list>(arg);
    return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid);
  }

  if (py::isinstance<py::dict>(arg)) {
    auto dict_arg = py::cast<py::dict>(arg);
    return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
  }

  return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
         py::isinstance<Number>(arg) || (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
}
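
// Illustrative note (not part of the build): CheckArgValid recurses into
// containers and accepts only plain Python scalars, None, mstype numbers and
// non-parameter tensors. For example (hypothetical Python inputs):
//   net(1, 2.5, Tensor(data), (1, Tensor(data)))  -> valid
//   net(np.ones(4))                               -> invalid: raw numpy array
//   net(weight_parameter)                         -> invalid: a Parameter carries
//                                                    the __parameter__ attribute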

std::string GetCompileExceptionInfo() {
  std::ostringstream oss;
  trace::GetTraceStackInfo(oss);
  return oss.str();
}

void SetLoopCount(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto func_graph = resource->func_graph();
  if (func_graph != nullptr && func_graph->manager() != nullptr) {
    auto manager = func_graph->manager();
    size_t graph_nums = manager->func_graphs().size();
    int64_t loop_size = ConfigManager::GetInstance().iter_num();
    const auto context_ptr = MsContext::GetInstance();
    if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
      resource->set_vm_loop(!context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK), loop_size);
    } else if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
      bool run_with_mind_rt = graph_nums == 1 || context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT);
      resource->set_vm_loop(!run_with_mind_rt, loop_size);
    }
    MS_LOG(INFO) << "Change vm_loop_flag to " << resource->vm_loop_flag() << ", set loop_size to " << loop_size;
  }
}

void GetCachedFuncGraph(const ResourcePtr &resource, const std::string &queue_name) {
  MS_EXCEPTION_IF_NULL(resource);
  auto realpath = Common::CreatePrefixPath(kCompileCacheFilePath);
  if (!realpath.has_value()) {
    MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
  }
  std::ifstream f(realpath.value());
  bool cache_file_existed = f.good();
  f.close();
  if (!cache_file_existed) {
    MS_LOG(WARNING) << "The compilation cache file '" << realpath.value()
                    << "' does not exist. Execute all the compilation actions.";
    return;
  }
  MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
  FuncGraphPtr fg = mindspore::LoadMindIR(realpath.value());
  if (fg == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
  }
  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    auto res_mng = resource->manager();
    MS_EXCEPTION_IF_NULL(res_mng);
    res_mng->AddFuncGraph(fg);
    fg->set_manager(res_mng);
  }
  auto cnodes = fg->GetOrderedCnodes();
  for (auto cnode : cnodes) {
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim != nullptr && prim->HasAttr("shared_name")) {
      prim->set_attr("shared_name", MakeValue(queue_name));
      break;
    }
  }
  resource->set_func_graph(fg);
}

void CacheFuncGraph(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto realpath = Common::CreatePrefixPath(kCompileCacheFilePath);
  if (!realpath.has_value()) {
    MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
  }

  ChangeFileMode(realpath.value(), S_IRWXU);
  std::ofstream fout(realpath.value());
  if (!fout.is_open()) {
    MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed! " << ErrnoToString(errno);
  }
  FuncGraphPtr fg = resource->func_graph();
  mind_ir::ModelProto fg_model = GetBinaryProto(fg, true);
  if (!fg_model.SerializeToOstream(&fout)) {
    MS_LOG(EXCEPTION) << "Failed to cache the graph to file " << realpath.value();
  }
  fout.close();
  ChangeFileMode(realpath.value(), S_IRUSR);
}
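
// Compile-cache round trip (an illustrative sketch, not part of the build):
// CacheFuncGraph above is the writer half of GetCachedFuncGraph. Guarded by the
// MS_CTX_SAVE_COMPILE_CACHE / MS_CTX_LOAD_COMPILE_CACHE context flags, a first
// run serializes the validated graph to compile_cache.mindir and later runs
// reload it and execute the backend actions only, e.g. (hypothetical Python):
//   context.set_context(save_compile_cache=True)  # first run: write the cache
//   context.set_context(load_compile_cache=True)  # later runs: load + backend only
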
}  // namespace

void CheckArgsValid(const py::tuple &args) {
  for (size_t i = 0; i < args.size(); i++) {
    if (!CheckArgValid(args[i])) {
      MS_EXCEPTION(TypeError)
        << "The input types of the outermost network support bool, int, float, None, tensor, "
           "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
           "and tuple or list containing only these types, and dict whose values are these types, but the "
        << i << "th arg type is " << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
    }
  }
}

py::object GenerateArgumentsKey(const std::unordered_map<std::string, py::object> &args) {
  MS_LOG(DEBUG) << "GenerateArgumentsKey args size:" << args.size();
  abstract::AbstractBasePtrList args_spec;

  for (const auto &arg : args) {
    if (py::isinstance<py::module>(arg.second)) {
      MS_LOG(EXCEPTION) << "GenerateArgumentsKey failed, argument input should not be py::module";
    }
    ValuePtr converted = nullptr;
    if (!parse::ConvertData(arg.second, &converted)) {
      MS_LOG(EXCEPTION) << "GenerateArgumentsKey convert arg failed";
    }
    args_spec.push_back(ArgsToAbstract(converted));
  }

  uint64_t key;
  auto iter = g_args_cache.find(args_spec);
  if (iter == g_args_cache.end()) {
    static uint64_t key_counter = 0;
    key = key_counter;
    ++key_counter;
    g_args_cache[args_spec] = key;
    MS_LOG(INFO) << "Generate a new compile key for new args, key: " << key;
  } else {
    key = iter->second;
  }
  return py::int_(key);
}
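
// Illustrative note (an assumption about the caching contract, not part of the
// build): argument sets that abstract to the same shapes/dtypes share a key, so
// the frontend can reuse an already compiled phase, e.g.:
//   GenerateArgumentsKey({{"x", fp32 tensor [2, 3]}})  -> key 0
//   GenerateArgumentsKey({{"x", fp32 tensor [2, 3]}})  -> key 0 (cache hit)
//   GenerateArgumentsKey({{"x", fp32 tensor [4, 3]}})  -> key 1 (new abstract)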

py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs) {
  MS_LOG(DEBUG) << "Verify args size:" << inputs.size();
  if (inputs.size() != input_signature.size()) {
    MS_LOG(ERROR) << "Signature size not equal to args size";
    return false;
  }

  size_t count = 0;
  for (auto arg_obj : inputs) {
    if (py::isinstance<Tensor>(arg_obj)) {
      MS_LOG(DEBUG) << "Verify Tensor";
      auto m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
      if (m_tensor == nullptr) {
        MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
        return false;
      }
      auto sig = input_signature[count].cast<std::shared_ptr<MetaTensor>>();
      ShapeVector sig_shape = sig->shape();
      TypePtr sig_type = sig->Dtype();

      ShapeVector tensor_shape = m_tensor->shape_c();
      if (tensor_shape != sig_shape) {
        MS_LOG(ERROR) << "Python input shape is incompatible with input_signature";
        return false;
      }

      if (*m_tensor->Dtype() != *sig_type) {
        MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature("
                      << sig_type->ToString() << ")";
        return false;
      }
    }
    count++;
  }

  return true;
}
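
// Usage sketch (illustrative; the signature values are hypothetical): the
// signature list and the runtime inputs are matched positionally, so
//   input_signature = [MetaTensor(fp32, [1, 3, 224, 224])]
//   VerifyInputSignature(input_signature, (fp32 tensor [1, 3, 224, 224],))  // true
//   VerifyInputSignature(input_signature, (fp16 tensor [1, 3, 224, 224],))  // false: dtype mismatch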

void PipelineRDRProcess(const FuncGraphPtr &graph, const std::vector<ActionItem> &actions, const ActionItem &action,
                        size_t i) {
  MS_LOG(INFO) << "Recording FuncGraph in pipeline using RDR.";
  std::string name = GetBaseNameForIR(SizeToLong(i), action.first);
  if (graph != nullptr) {
    auto graph_clone = BasicClone(graph);
    if (graph_clone != nullptr) {
      DumpGraphParams dump_params = {false, static_cast<int>(kTopStack)};
      if (i == actions.size()) {
        dump_params.dump_mode = static_cast<int>(kWholeStack);
      }
      (void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph_clone, dump_params, ".ir");
    } else {
      MS_LOG(WARNING) << "Clone FuncGraph failed in pipeline, no FuncGraph recording in RDR.";
    }
  } else {
    MS_LOG(WARNING) << "Pipeline Resource has no FuncGraph, no FuncGraph recording in RDR.";
  }
  MS_LOG(INFO) << "Recording FuncGraph in pipeline end.";
}

GraphExecutorPy::GraphExecutorPy() {}

ResourcePtr GraphExecutorPy::GetResource(const std::string &phase) {
  MS_LOG(DEBUG) << "Phase size:" << info_.size();
  if (info_.count(phase) == 0) {
    return nullptr;
  }
  return info_[phase]->resource;
}

FuncGraphPtr GraphExecutorPy::GetFuncGraph(const std::string &phase) {
  if (info_.count(phase) == 0) {
    MS_LOG(EXCEPTION) << "No executor info found for phase: " << phase;
  }
  return info_[phase]->func_graph;
}

FuncGraphPtr GraphExecutorPy::GetGradGraph(const std::string &phase) {
  if (phase.empty()) {
    MS_LOG(EXCEPTION) << "The input phase is empty.";
  }
  if (info_.count(phase) == 0) {
    MS_LOG(EXCEPTION) << "No phase in executor:" << phase;
  }

  auto execute_info = info_[phase];
  MS_EXCEPTION_IF_NULL(execute_info);
  auto grad_graph = execute_info->grad_graph;
  MS_EXCEPTION_IF_NULL(grad_graph);
  return grad_graph;
}

void GraphExecutorPy::SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase) {
  if (phase.empty()) {
    MS_LOG(EXCEPTION) << "The input phase is empty.";
  }
  if (info_.count(phase) == 0) {
    MS_LOG(EXCEPTION) << "No phase in executor: " << phase;
  }

  auto execute_info = info_[phase];
  MS_EXCEPTION_IF_NULL(execute_info);
  if (execute_info->grad_graph != nullptr) {
    MS_LOG(DEBUG) << "The grad graph already exists, phase is: " << phase;
  }
  MS_EXCEPTION_IF_NULL(grad_graph);
  execute_info->grad_graph = grad_graph;
}

compile::VmEvalFuncPtr GraphExecutorPy::GetVmEvalFunc(const std::string &phase) {
  ResourcePtr res = GetResource(phase);
  MS_EXCEPTION_IF_NULL(res);
  if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is<compile::VmEvalFuncPtr>()) {
    return res->results()[kOutput].cast<compile::VmEvalFuncPtr>();
  }
  MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput;
  return nullptr;
}

bool GraphExecutorPy::HasCompiled(const std::string &phase) const {
  if (info_.count(phase) == 0) {
    return false;
  }
  return true;
}

py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) {
  FuncGraphPtr fg_ptr = GetFuncGraph(phase);
  if (fg_ptr == nullptr) {
    for (auto &item : info_) {
      MS_LOG(DEBUG) << "Phase key is: " << item.first;
    }
    MS_LOG(EXCEPTION) << "Cannot find func graph " << phase;
  }

  if (ir_type == IR_TYPE_ANF) {
    std::string proto_str = GetFuncGraphProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export ANF format model failed.";
    }
    return proto_str;
  }

  if (ir_type == IR_TYPE_ONNX) {
    std::string proto_str = GetOnnxProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export ONNX format model failed.";
    }
    return proto_str;
  }

  if (ir_type == IR_TYPE_MINDIR) {
    std::string proto_str = GetBinaryProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export MINDIR format model failed.";
    }
    return proto_str;
  }

  MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
}
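
// Export sketch (illustrative; the pybind method name is an assumption): the
// three ir_type strings map to the serializers above, e.g. from Python:
//   proto = executor.get_func_graph_proto(phase, "anf_ir")   # textual ANF IR
//   proto = executor.get_func_graph_proto(phase, "onnx_ir")  # ONNX protobuf bytes
//   proto = executor.get_func_graph_proto(phase, "mind_ir")  # MindIR protobuf bytes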

py::dict GraphExecutorPy::GetParameterLayout(const std::string &phase) {
  MS_LOG(DEBUG) << "GetParameterLayout!";
  std::string layout_graph = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(layout_graph);
  return mindspore::parallel::GetParameterLayout(graph);
}

py::dict GraphExecutorPy::GetCNodeStrategy(const std::string &phase) {
  MS_LOG(DEBUG) << "GetCNodeStrategy!";
  return stra_dict_[phase];
}

py::list GraphExecutorPy::GetParallelParameterNameList(const std::string &phase) {
  std::string param_graph = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(param_graph);
  return mindspore::parallel::GetParallelParameterNameList(graph);
}

void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
  MS_LOG(DEBUG) << "SetCNodeStrategy!";
  stra_dict_[phase_][py::str(name)] = strategy;
}

size_t GraphExecutorPy::GetNumOpsInfo(const std::string &phase) {
  MS_LOG(DEBUG) << "GetNumOpsInfo!";
  return phase_to_num_op_info_[phase];
}

void GraphExecutorPy::SetNumOpsInfo(size_t num_ops) {
  MS_LOG(DEBUG) << "SetNumOpsInfo!";
  phase_to_num_op_info_[phase_] = num_ops;
}

py::dict GraphExecutorPy::GetAllreduceFusion(const std::string &phase) {
  MS_LOG(INFO) << "GetAllreduceFusion!";
  auto graph = GetFuncGraph(phase);
  return mindspore::parallel::GetAllreduceFusion(graph);
}

// Multi-threaded calls are not supported, and neither are nested calls.
// The nested_called flag below is used to guard against nested calls.
void GraphExecutorPy::DelNetRes(const std::string &id) {
  static bool nested_called = false;
  if (nested_called) {
    return;
  }
  nested_called = true;
#ifdef ENABLE_GE
  FinalizeBackend();
#else
  ConfigManager::GetInstance().ResetIterNum();
#endif
  if (executor_ != nullptr) {
    bool flag = false;
    auto tmp_info = info_;
    for (auto &item : tmp_info) {
      if (item.first.find(id) != string::npos) {
        MS_LOG(DEBUG) << "Delete network res:" << item.first;
        item.second = nullptr;
        (void)info_.erase(item.first);
        flag = true;
      }
    }

    MS_LOG(DEBUG) << "Delete flag:" << flag;
#ifdef ENABLE_GE
    if (flag && info_.size() == 0) {
      // Because GE only supports one Session existing at a time, delete the old one.
      transform::DfGraphManager::GetInstance().DeleteGraphRunner();
      transform::DfGraphManager::GetInstance().EraseAnfGraph();
      transform::DfGraphManager::GetInstance().DeleteGeSession();
    }
#endif
  }
  nested_called = false;
}

void GraphExecutorPy::ClearRes() {
  MS_LOG(INFO) << "Clean executor resource!";
  executor_ = nullptr;
}

GraphExecutorPy::~GraphExecutorPy() {
  MS_LOG(INFO) << "Release Executor!";
  ConfigManager::GetInstance().ResetConfig();
}

void GraphExecutorPy::GetWeightInfo(
  const CNodePtr &root_node, const AnfNodePtr &weight_node,
  std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) {
  MS_EXCEPTION_IF_NULL(root_node);
  MS_EXCEPTION_IF_NULL(fake_quant_table);
  std::string weight_name;
  auto x = root_node->input(1);
  MS_EXCEPTION_IF_NULL(x);
  if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
    weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
  } else {
    auto para = weight_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(para);
    weight_name = para->name();
  }
  // Find the fakequant node from the input.
  int64_t count = 0;
  const int64_t max_depth = 5;
  CNodePtr cnode = nullptr;
  auto is_quant_cnode = [](const AnfNodePtr &node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerChannel);
  };
  while (!is_quant_cnode(x)) {
    if (count >= max_depth) {
      break;
    }
    cnode = x->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() <= 1) {
      break;
    }
    x = cnode->input(1);
    count += 1;
  }
  if (x->isa<Parameter>() || IsPrimitiveCNode(x, prim::kPrimLoad)) {
    (*fake_quant_table)[weight_name] = std::make_pair(nullptr, "input");
  }
  // Get the name of the fakequant parameter minq.
  if (!is_quant_cnode(x)) {
    return;
  }
  cnode = x->cast<CNodePtr>();
  constexpr size_t expect_input_size = 4;
  if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != expect_input_size) {
    return;
  }
  const size_t fakequant_index = 2;
  auto fakequant_min_node = cnode->input(fakequant_index);
  if (!fakequant_min_node->isa<Parameter>() && !IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
    return;
  }
  std::string fakequant_min_node_name;
  if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
    fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
  } else {
    auto param = fakequant_min_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    fakequant_min_node_name = param->name();
  }
  auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
  MS_EXCEPTION_IF_NULL(quant_op_value);
  if (!quant_op_value->isa<PrimitivePy>()) {
    return;
  }
  auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
  (*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name);
}

std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> GraphExecutorPy::FetchInfoForQuantExport(
  const std::string &phase) {
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase << ")!";
  std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> fake_quant_table;
  auto filter = [](const AnfNodePtr &node) {
    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
             IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
  };
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
  auto is_quant_cnode = [](const AnfNodePtr &node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerChannel);
  };
  const size_t root_node_size = 3;
  const size_t weight_index = 2;
  for (const auto &node : nodes) {
    auto root_node = node->cast<CNodePtr>();
    if (root_node == nullptr || root_node->size() != root_node_size) {
      continue;
    }
    auto weight = root_node->input(weight_index);
    if (!is_quant_cnode(weight)) {
      auto tuple_node = weight->cast<CNodePtr>();
      if (tuple_node != nullptr) {
        auto fake_node = tuple_node->input(1);
        if (!is_quant_cnode(fake_node)) {
          continue;
        } else {
          weight = fake_node;
        }
      }
    }
    // Get the name of the parameter weight.
    auto cnode = weight->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto weight_node = cnode->input(weight_index);
    if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
      continue;
    }
    GetWeightInfo(root_node, weight_node, &fake_quant_table);
  }
  return fake_quant_table;
}

void GraphExecutorPy::SaveCompiledGraph(const std::string &phase) {
  // Save the graph to GraphExecutorPy.
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();

  MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase << ")!";
  info_[phase]->func_graph = func_graph;
  if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) &&
      ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
    func_graph = info_[phase]->resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();
    ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
    std::string layout_graph = phase + kStepParallelGraph;
    executor_info->func_graph = func_graph;
    info_[layout_graph] = executor_info;
  } else {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!";
  }
  MS_LOG(INFO) << "End save compiled func graph!";
}

void GraphExecutorPy::GetGeBackendPolicy() const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string backend = ms_context->backend_policy();
  if (backend != "ge") {
    MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!";
  }
}

bool IsPhaseExportAir(const std::string &phase) {
  auto phase_to_export = "export.air";
  return phase.rfind(phase_to_export) != std::string::npos;
}

bool IsPhaseTrain(const std::string &phase) {
  const std::string phase_to_train = "train";
  return phase.rfind(phase_to_train) != std::string::npos;
}

bool IsPhaseLoadFromMindIR(const std::string &phase) {
  const std::string mindir_graph = "graph_load_from_mindir";
  return phase.rfind(mindir_graph) != std::string::npos;
}

std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::string &phase, bool use_vm) {
  MS_EXCEPTION_IF_NULL(resource);
  bool is_air = IsPhaseExportAir(phase);

  std::string backend = MsContext::GetInstance()->backend_policy();

#if ((defined ENABLE_CPU) && (!defined _WIN32))
  const std::string &server_mode = ps::PSContext::instance()->server_mode();
  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
      ps::PSContext::instance()->is_server()) {
    return ServerPipeline();
  }
  if (ps::PSContext::instance()->is_server()) {
    resource->results()[kBackend] = compile::CreateBackend();
    return PServerPipeline();
  }
  if (ps::PSContext::instance()->is_scheduler()) {
    return PSchedulerPipeline();
  }
#endif

  if (use_vm && backend != "ge" && !is_air) {
    compile::SetMindRTEnable();
    // Create backend.
    auto backend_ptr = compile::CreateBackend();
#ifdef ENABLE_DEBUGGER
    // Connect session to debugger.
    backend_ptr->SetDebugger();
#endif
    resource->results()[kBackend] = backend_ptr;
    // If the 'use_frontend_compile_cache' context has been set to true and the cache is read successfully,
    // do the backend actions only.
    if (IsPhaseTrain(phase) && MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE) &&
        resource->func_graph() != nullptr) {
      return BackendPipeline();
    }
    if (IsPhaseLoadFromMindIR(phase)) {
      return MindIRPipeline();
    }
    return VmPipeline();
  }
  return GePipeline();
}
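
// Selection summary (a plain reading of the branches above, for illustration):
//   - FL/hybrid server mode                -> ServerPipeline()
//   - PS server / PS scheduler             -> PServerPipeline() / PSchedulerPipeline()
//   - VM backend, not "ge", not AIR export:
//       cached train graph already loaded  -> BackendPipeline() (backend actions only)
//       phase "graph_load_from_mindir"     -> MindIRPipeline()
//       otherwise                          -> VmPipeline()
//   - everything else                      -> GePipeline()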

bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj,
                                   bool use_vm, const std::string &queue_name) {
  // Check if the phase is valid.
  if ((!py::isinstance<py::str>(phase_obj))) {
    MS_LOG(ERROR) << "The `phase` must be a string.";
    return false;
  }
  // Check if the function or net is valid.
  if (py::isinstance<py::none>(source_obj)) {
    MS_LOG(ERROR) << "The source object to compile should not be None.";
    return false;
  }
  // Check if the args of the function or net are valid.
  CheckArgsValid(args);

  auto phase = py::cast<std::string>(phase_obj);
  MS_LOG(INFO) << "Start compiling, phase: " << phase << ".";
  MS_LOG(DEBUG) << "Compiling source: {" << py::str(source_obj)
                << "}\n\n Args: " << py::str(const_cast<py::tuple &>(args));

#ifdef ENABLE_GE
  GetGeBackendPolicy();
#endif
  ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
  ResourcePtr resource = std::make_shared<Resource>(source_obj);

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE)) {
#ifdef ENABLE_PROFILE
    double t1 = GetTime();
#endif
    GetCachedFuncGraph(resource, queue_name);
#ifdef ENABLE_PROFILE
    double t2 = GetTime();
    MsProfile::StatTime("LoadCachedFuncGraph", t2 - t1);
#endif
  }

  phase_ = phase;
  auto actions = GetPipeline(resource, phase, use_vm);
  std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(actions, phase));

  // Get the parameter items and add their values to args_spec.
  abstract::AbstractBasePtrList args_spec;
  std::size_t size = args.size();
  for (std::size_t i = 0; i < size; i++) {
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(args[i], &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "Failed to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
    }
    args_spec.push_back(ArgsToAbstract(converted));
  }
  resource->set_args_spec(args_spec);
  executor_info->arg_list_size = size;
  executor_info->resource = resource;
  info_[phase] = executor_info;
  pip->Run(phase);

  // Save the compiled graph to MsPipeLine.
  SaveCompiledGraph(phase);

  opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes();
  abstract::AnalysisContext::ClearContext();
  // Reclaim all resources used by the optimizer.
  ReclaimOptimizer();
  resource->Clean();

  MS_LOG(INFO) << "Finish compiling.";
  return true;
}

std::vector<ActionItem> GraphExecutorPy::FilterActions(const std::vector<ActionItem> &actions,
                                                       const std::string &phase) {
  // Filter out the actions after 'validate' when the phase is 'export'.
  if (GetPhasePrefix(phase).rfind("export", 0) == std::string::npos) {
    return actions;
  }
  MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'";
  std::vector<ActionItem> filtered_actions;
  for (const auto &item : actions) {
    filtered_actions.emplace_back(item);
    if (item.first == "validate") {
      break;
    }
  }
  return filtered_actions;
}
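
// Example (illustrative): for a phase whose prefix is "export", every action
// after "validate" (e.g. "task_emit" and "execute") is dropped, so the pipeline
// reduces to parse -> ... -> validate, while a phase such as "train.0" keeps
// the full action list.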

void GraphExecutorPy::ReleaseResource(const py::object &phase) {
  ResourcePtr res = GetResource(py::cast<std::string>(phase));
  if (res != nullptr) {
    res->Clean();
  }
  // Reclaim all resources used by the optimizer.
  ReclaimOptimizer();
}

bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase, bool use_vm,
                              const std::string &queue_name) {
  bool ret_value = false;
  try {
    ret_value = CompileInner(source_obj, args, phase, use_vm, queue_name);
  } catch (const py::error_already_set &ex) {
    if (!StaticAnalysisException::Instance().HasException()) {
      // Print the function call stack info before release.
      std::string exception_info = GetCompileExceptionInfo();
      if (!exception_info.empty()) {
        MS_LOG(ERROR) << exception_info;
      }
    }
    ReleaseResource(phase);

    // Re-throw this exception to the Python interpreter to handle it.
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    ReleaseResource(phase);
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    ReleaseResource(phase);
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    ReleaseResource(phase);
    throw py::index_error(ex);
  } catch (const py::key_error &ex) {
    ReleaseResource(phase);
    throw py::key_error(ex);
  } catch (const py::attribute_error &ex) {
    ReleaseResource(phase);
    throw py::attribute_error(ex);
  } catch (const py::name_error &ex) {
    ReleaseResource(phase);
    throw py::name_error(ex);
  } catch (const std::exception &ex) {
    ReleaseResource(phase);
    // Re-throw this exception to the Python interpreter to handle it.
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    ReleaseResource(phase);
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compiling graph. Exception name: " << exName;
  }
  return ret_value;
}

void CacheValidateFuncGraph(const std::string &phase, const ResourcePtr &resource) {
  if (IsPhaseTrain(phase) && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_COMPILE_CACHE)) {
#ifdef ENABLE_PROFILE
    double t1 = GetTime();
#endif
    CacheFuncGraph(resource);
#ifdef ENABLE_PROFILE
    double t2 = GetTime();
    MsProfile::StatTime("SaveCacheFuncGraph", t2 - t1);
#endif
  }
}

void Pipeline::Run(const std::string &phase) {
  MS_LOG(INFO) << "Pipeline run";
  MS_EXCEPTION_IF_NULL(resource_);
  FuncGraphPtr user_graph = nullptr;

  WITH(MsProfile::GetProfile())[&user_graph, &phase, this]() {
    size_t i = 0;
    for (auto &action : actions_) {
#ifdef ENABLE_TIMELINE
      DumpTime &dump_time = DumpTime::GetInstance();
      dump_time.Record(action.first, GetTime(), true);
#endif
      bool result = true;
      WITH(MsProfile::GetProfile()->Step(action.first))[&result, &action, this]() {
        MS_LOG(DEBUG) << "Action " << action.first << " start ...";
        result = action.second(resource_);
        MS_LOG(DEBUG) << "Action " << action.first << " end.";
      };
      if (action.first == "task_emit") {
        SetLoopCount(resource_);
      } else if (action.first == "validate") {
        CacheValidateFuncGraph(phase, resource_);
      }
      if (!result) {
        MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
      }

      FuncGraphPtr graph = resource_->func_graph();
#ifdef ENABLE_DUMP_IR
      if (mindspore::RecorderManager::Instance().RdrEnable()) {
        PipelineRDRProcess(graph, actions_, action, i);
      }
      if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && graph != nullptr) {
        user_graph = graph;
        std::string base_name = GetBaseNameForIR(SizeToLong(i), action.first);

        // Generate IR file in dot format, which can be converted to an svg file using the graphviz dot command.
        draw::Draw(base_name + ".dot", graph);
        // Generate IR file in human-readable format.
        if (i == actions_.size() - 1) {
          DumpIR(base_name + ".ir", graph, false, kWholeStack);
        } else {
          DumpIR(base_name + ".ir", graph, false, kTopStack);
        }
        // Generate IR file in a heavily commented format, which can also be reloaded.
        ExportIR(base_name + ".dat", graph);
      }
#endif
      i++;
#ifdef ENABLE_TIMELINE
      dump_time.Record(action.first, GetTime(), false);
#endif
    }
  };
#ifdef ENABLE_PROFILE
  MsProfile::Print();
  MsProfile::Reset();
#endif

#ifdef ENABLE_DUMP_IR
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) {
    draw::DrawUserFuncGraph("ModelDigraph.dot", user_graph);
  }
#endif
  MS_LOG(INFO) << "End";
}

void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
  MS_EXCEPTION_IF_NULL(arg_list);
  std::size_t size = args.size();
  bool arg_list_inited = !arg_list->empty();
  for (std::size_t i = 0; i < size; i++) {
    py::object arg = args[i];
    auto ms_context = MsContext::GetInstance();
    if (ms_context->backend_policy() == kMsConvert && py::isinstance<py::array>(arg)) {
      MS_LOG(EXCEPTION) << "The " << i << "th arg is a numpy array, not a tensor.";
    }
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(arg, &converted);
    if (!succ) {
      MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
    }
    if (!arg_list_inited) {
      arg_list->push_back(converted);
      continue;
    }
    if (i >= arg_list->size()) {
      MS_LOG(EXCEPTION) << "i:" << i << " out of range:" << arg_list->size();
    }
    (*arg_list)[i] = converted;
  }

  MS_EXCEPTION_IF_NULL(res);
  auto graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_params = graph->parameters();
  std::size_t graph_params_size = graph_params.size();
  if ((*arg_list).size() != graph_params_size) {
    // Maybe some parameters have default values.
    for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
      MS_EXCEPTION_IF_NULL(graph_params[i]);
      auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (!param_ptr->has_default()) {
        MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
      }
      if (!param_ptr->default_param()->isa<Tensor>()) {
        MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString()
                          << "] is not initialized, need to call `.init_data()`";
      }
      arg_list->push_back(param_ptr->default_param());
    }
  }
}
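
// Illustrative example (hypothetical parameters): for a graph with parameters
// (x, w, b) where w and b carry default values, a call supplying only x is
// padded from the defaults:
//   arg_list before: [x_tensor]
//   arg_list after : [x_tensor, w_default, b_default]
// A default that is not a materialized Tensor (no `.init_data()` call yet)
// raises instead of being appended.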

void GraphExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) {
  ProcessVmArgInner(args, GetResource(phase), arg_list);
}

#ifdef ENABLE_DEBUGGER
void GraphExecutorPy::TerminateDebugger() {
  if (debugger_terminate_) {
    MS_LOG(INFO) << "Terminate debugger and clear resources!";
    ClearResAtexit();
    if (exit_success_) {
      exit(0);
    } else {
      exit(1);
    }
  }
}
#endif

py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) {
  // The MindSpore debugger notifies the main thread to exit after one step, and will not run the next step.
#ifdef ENABLE_DEBUGGER
  TerminateDebugger();
#endif
  std::size_t size = args.size();
  if (!py::isinstance<py::str>(phase_obj)) {
    MS_LOG(EXCEPTION) << "Run failed, phase input is not a str";
  }
  auto phase = py::cast<std::string>(phase_obj);
  std::string backend = MsContext::GetInstance()->backend_policy();
#ifdef ENABLE_GE
  if (backend == "ge") {
    return ExecDFGraph(info_, args, phase);
  }
#else
  auto ret_val = std::make_shared<py::object>();
  if (info_.count(phase) != 0 && info_[phase]->func_graph != nullptr) {
    if (IsGraphOutputValueNodeOrParameter(info_[phase]->func_graph->output(), args, ret_val)) {
      // Check that each input arg is a Tensor when the backend is "ms".
      if (MsContext::GetInstance()->backend_policy() == kMsConvert) {
        for (std::size_t i = 0; i < size; i++) {
          ValuePtr converted = nullptr;
          if (!parse::ConvertData(args[i], &converted)) {
            MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
          }
        }
      }
      return *ret_val;
    }
  }
  if (backend == "ge") {
    // Virtual output constructed for test cases.
    if (!args.empty()) {
      return args[0];
    }
    return args;
  }
#endif
  auto iter = info_.find(phase);
  if (iter == info_.end()) {
    MS_LOG(EXCEPTION) << "No executor info found for phase: " << phase;
  }
  auto &execute_info = iter->second;
  MS_EXCEPTION_IF_NULL(execute_info);
  if (size > execute_info->arg_list_size) {
    MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << execute_info->arg_list_size;
  }
  ProcessVmArg(args, phase, &execute_info->arg_list);
  // Start to run phase.
  compile::VmEvalFuncPtr run = GetVmEvalFunc(phase);
  if (run == nullptr) {
    MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase;
  }
  // Set loopsink size for each phase.
  bool vm_loop_flag = info_[phase]->resource->vm_loop_flag();
  int64_t loop_size = info_[phase]->resource->loop_size();
  int64_t vm_loop = 1;
  if (vm_loop_flag) {
    vm_loop = loop_size;
  } else {
    // Set the loop size in config if the number of graphs is 1 (is_loop_sink=True); then there will be a loop
    // embracing 'Execute(graph)' in GPUSession.
    ConfigManager::GetInstance().set_gpu_loopsink_size(loop_size);
  }
  MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
  py::object ret;
  MS_LOG(DEBUG) << "Eval run " << backend;
  for (int64_t i = 0; i < vm_loop; i++) {
    BaseRef value = (*run)(execute_info->arg_list);
    ret = BaseRefToPyData(value);
  }
  MS_LOG(DEBUG) << "Run end";
  return ret;
}
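
// Loop-sink sketch (a reading of the code above, for illustration): when
// vm_loop_flag is set the host loop drives the compiled graph loop_size times
// per Run() call; otherwise the count is pushed into the GPU session via
// set_gpu_loopsink_size and the host invokes the graph once:
//   vm_loop_flag == true  -> for (i = 0; i < loop_size; ++i) (*run)(arg_list);
//   vm_loop_flag == false -> gpu loopsink = loop_size; single (*run)(arg_list);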

FuncGraphPtr GraphExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase,
                                         const py::object &broadcast_params) {
#if ((defined ENABLE_GE) || (defined ENABLE_D))
  return BuildDFGraph(info_, init_params, phase, broadcast_params);
#else
  return nullptr;
#endif
}

void GraphExecutorPy::UpdataParamNodeDefaultInput(
  const std::string &phase, const std::unordered_map<std::string, tensor::TensorPtr> &params_value) {
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase
                << ")!";
  auto &params = func_graph->parameters();
  for (const auto &param : params) {
    MS_EXCEPTION_IF_NULL(param);
    auto param_cast = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_cast);
    auto iter = params_value.find(param_cast->name());
    if (iter != params_value.end()) {
      param_cast->set_default_param(iter->second);
    }
  }
}

void GraphExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) const {
#ifdef ENABLE_GE
  RunGEInitGraph(init_params, phase);
#endif
}

void GraphExecutorPy::PyExePath(const py::object &py_exe_path) {
  if (!py::isinstance<py::str>(py_exe_path)) {
    MS_LOG(EXCEPTION) << "Failed, py_exe_path input is not a str";
  }
  auto py_exe_path_s = py::cast<std::string>(py_exe_path);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
}

void GraphExecutorPy::KernelBuildServerDir(const py::object &kernel_build_server_dir) {
  if (!py::isinstance<py::str>(kernel_build_server_dir)) {
    MS_LOG(EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
  }
  auto kernel_build_server_dir_s = py::cast<std::string>(kernel_build_server_dir);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
}

bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
                     const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                     const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run) {
  std::string name = MsContext::GetInstance()->backend_policy();
#ifndef NO_DLIB
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::IsTsdOpened(ms_context) || !context::IsGeInited(ms_context)) {
    InitPipeline();
  }
#endif
  if (iter_num == -1) {
    iter_num = INT32_MAX;
  }
  if (name == kMsConvert || name == kMsVm) {
    return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
  }
#ifdef ENABLE_GE
  return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase);
#else
  std::string backend = MsContext::GetInstance()->backend_policy();
  if (backend == "ge") {
    return true;
  }
#endif
  return false;
}

bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
                       const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                       const std::vector<int64_t> &input_indexes, bool need_run) {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
    return true;
  }
#endif
  MS_LOG(INFO) << "Start InitDataSet Entry";
  mindspore::parse::python_adapter::set_python_env_flag(true);
  ShapeVector int_input_indexes;
  (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
                       [](int64_t item) { return static_cast<int64_t>(item); });
  std::vector<ShapeVector> int_shapes;
  (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes),
                       [](const std::vector<int64_t> &item) {
                         ShapeVector vector_item;
                         (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item),
                                              [](int64_t inner_item) { return static_cast<int64_t>(inner_item); });
                         return vector_item;
                       });
  auto p_init = std::make_shared<Primitive>("InitDataSetQueue");
  p_init->set_attr("queue_name", MakeValue(queue_name));
  p_init->set_attr("size", MakeValue(static_cast<int64_t>(size)));
  p_init->set_attr("batch_size", MakeValue(static_cast<int64_t>(batch_size)));
  p_init->set_attr("types", MakeValue(types));
  p_init->set_attr("shapes", MakeValue(int_shapes));
  p_init->set_attr("input_indexes", MakeValue(int_input_indexes));

  const std::vector<std::string> empty_str_list;
  p_init->set_attr("input_names", MakeValue(empty_str_list));
  p_init->set_attr("output_names", MakeValue(empty_str_list));

  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  auto app_init = std::make_shared<CNode>(AnfNodePtrList{NewValueNode(p_init)}, func_graph);
  func_graph->set_output(app_init);
  auto manager = MakeManager();
  manager->AddFuncGraph(func_graph);

  // AbstractNone indicates that there is no output for this apply node.
  auto abstract_none = std::make_shared<abstract::AbstractNone>();
  app_init->set_abstract(abstract_none);
  // Before graph compiling, the iter num needs to be reset.
  ConfigManager::GetInstance().ResetIterNum();
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::ResetRecorder();
#endif

  compile::SetMindRTEnable();
  auto backend = compile::CreateBackend();
  MS_EXCEPTION_IF_NULL(backend);
  // The data set graph compiling and running of mindRT.
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
    MS_EXCEPTION_IF_NULL(mindrt_backend);
    auto &actor_info = mindrt_backend->CompileGraphs(func_graph);
    VectorRef args;
    if (need_run) {
      VectorRef outputs;
      mindrt_backend->RunGraph(actor_info, args, &outputs);
    }
    ConfigManager::GetInstance().set_iter_num(size);
    return true;
  }

  auto convert_fn = backend->convert_fn();
  MS_EXCEPTION_IF_NULL(convert_fn);
  // Convert CNodeList to LinConvertResult.
  auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
  auto runner = convert_fn(segment, "");
  ConfigManager::GetInstance().set_iter_num(size);
  // PS cache does not support loop sink.
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
    ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
    ConfigManager::GetInstance().set_iter_num(1);
  }
#endif

  if (!(*runner.run)) {
    // Empty function.
    MS_LOG(EXCEPTION) << "Backend " << backend->name() << " does not support the tdt dataset.";
  }

  // Launch the init dataset runner without inputs and outputs.
  VectorRef args;
  auto fn = runner.run;
  if (need_run) {
    (void)(*fn)(args);
  }
  MS_LOG(DEBUG) << "InitDataSetVm End.";
}
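// Construction sketch (illustrative): InitExecDatasetVm builds a one-node graph
// whose only CNode applies the InitDataSetQueue primitive carrying the queue
// attributes, roughly
//   graph() { return InitDataSetQueue[queue_name, size, batch_size, ...](); }
// and then either compiles and runs it through the mindRT backend or converts
// the single segment with convert_fn and invokes the resulting runner.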

void ResetOpId() { mindspore::id_generator::reset_id(); }

void InitHccl() {
#ifdef ENABLE_GE
  (void)InitPipeline();
#else
  mindspore::parse::python_adapter::set_python_env_flag(true);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
#if ENABLE_D
  bool task_sink = true;
  auto single_op = common::GetEnv(kGraphOpRun);
  auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
  if (single_op == "1" || enable_mem_scheduler == "1") {
    task_sink = false;
  }
  auto mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (!task_sink && mode == kGraphMode) {
    MS_LOG(INFO) << "MPI collective init.";
    if (!HcclCollectiveGroup::instance().InitCollective()) {
      MS_LOG(EXCEPTION) << "Mpi init failed, please check if mpirun is used correctly.";
    }
    device_id = IntToUint(HcclCollectiveGroup::instance().GetDeviceId());
    ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
    ms_context->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
  }
#endif
  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
  if (ms_context->backend_policy() == "ms" &&
      ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
    MS_EXCEPTION_IF_NULL(runtime_instance);
#ifndef ENABLE_SECURITY
    runtime_instance->PreInit();
#endif
    (void)context::OpenTsd(ms_context);
    if (!runtime_instance->Init()) {
      MS_LOG(EXCEPTION) << "Runtime init failed.";
    }
  } else {
    (void)context::OpenTsd(ms_context);
  }
#endif
#if (defined ENABLE_D)
#ifndef ENABLE_SECURITY
  if (!ProfilingManager::GetInstance().IsProfiling()) {
    ProfilingManager::GetInstance().SetHcclEnabledBefProfilingEnabled();
  }
#endif
#endif
}

void FinalizeHccl() {
#ifdef ENABLE_GE
  (void)FinalizeBackend();
#else
  session::ExecutorManager::Instance().Clear();
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
#endif
}

uint32_t GetHcclRankId() {
  uint32_t rank_id = 0;
  bool ret = CommManager::GetInstance().GetRankID("", &rank_id);
  if (!ret) {
    MS_LOG(ERROR) << "Get rank id failed, return rank id " << rank_id << " as default.";
  }
  return rank_id;
}

uint32_t GetHcclRankSize() {
  uint32_t rank_size = 0;
  bool ret = CommManager::GetInstance().GetRankSize("", &rank_size);
  if (!ret) {
    MS_LOG(ERROR) << "Get rank size failed, return rank size " << rank_size << " as default.";
  }
  return rank_size;
}

void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) {
#if ((defined ENABLE_GE) || (defined ENABLE_D))
  ExportDFGraph(file_name, phase);
#else
  MS_EXCEPTION(ValueError) << "Exporting a file in 'AIR' format is only supported with the Ascend backend.";
#endif
}

FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len,
                        const std::string &dec_mode) {
  auto func_graph =
    mindspore::LoadMindIR(file_name, false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode);
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    DumpIR("load.ir", func_graph);
  }
#endif
  return func_graph;
}

void ReleaseGeTsd() {
  auto context_ptr = MsContext::GetInstance();
  if (context_ptr != nullptr) {
    (void)context::FinalizeGe(context_ptr, true);
    (void)context::CloseTsd(context_ptr, true);
  }
}

#ifndef ENABLE_SECURITY
void StartUpProfiling() {
#ifdef ENABLE_D
  if (!ProfilingManager::GetInstance().IsProfiling()) {
    return;
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  MS_LOG(INFO) << "Startup profiling";
  // Start up profiling before OpenTsd.
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (ms_context->backend_policy() == "ms" &&
      ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
    MS_EXCEPTION_IF_NULL(runtime_instance);
    runtime_instance->PreInit();
  }
#endif
}
#endif

void InitPipeline() {
  // Set the python env flag.
  mindspore::parse::python_adapter::set_python_env_flag(true);
#ifndef ENABLE_SECURITY
  // Start up profiling before opening tsd.
  StartUpProfiling();
#endif
  // Open tsd before GE initializes.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::OpenTsd(ms_context)) {
    MS_LOG(EXCEPTION) << "Open tsd failed";
  }
  (void)context::InitGe(ms_context);
}

void FinalizeBackend() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  (void)context::FinalizeGe(context_ptr);
  (void)context::CloseTsd(context_ptr);
}

void ClearResAtexit() {
  MS_LOG(DEBUG) << "Pipeline clear all resource";
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      ps::ps_cache_instance.Finalize();
    }
    MS_LOG(INFO) << "Start finalizing worker.";
    const std::string &server_mode = ps::PSContext::instance()->server_mode();
    if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
      fl::worker::FLWorker::GetInstance().Finalize();
    } else {
      ps::Worker::GetInstance().Finalize();
    }
  }
#endif
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::ResetRecorder();
#endif
  session::ExecutorManager::Instance().Clear();
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
  runtime::GraphScheduler::GetInstance().Clear();
  device::DeviceContextManager::GetInstance().ClearDeviceContexts();
  ad::g_k_prims.clear();
  ad::ClearKPynativeCellStaticRes();
  ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();

  abstract::ClearPrimEvaluatorMap();
  pipeline::GetMethodMap().clear();
  pipeline::GetAttrMap().clear();
  pipeline::GraphExecutorPy::ClearRes();
  pipeline::ReclaimOptimizer();
  pynative::PynativeExecutor::GetInstance()->ClearRes();
  opt::python_pass::PyPassManager::GetInstance()->ClearRes();
#ifdef ENABLE_GE
  transform::DfGraphManager::GetInstance().ClearGraph();
  transform::OpAdapterMap::get().clear();
#else
  ConfigManager::GetInstance().ResetIterNum();
#endif
  ReleaseGeTsd();
  parse::python_adapter::ResetPythonScope();
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  abstract::AnalysisContext::ClearContext();
  abstract::AnalysisSchedule::GetInstance().Stop();
#ifdef ENABLE_DEBUGGER
  Debugger::GetInstance()->Reset();
#endif
  g_args_cache.clear();
  // Clean up static variables to prevent a crash, since static variables are released after
  // the Python threads are released.
  parse::data_converter::ClearObjectCache();
  parse::Parser::CleanParserResource();
  parse::CleanDataClassToClassMap();
  trace::ClearTraceStack();
}

py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode) {
  size_t encrypt_len;
  auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
                                         reinterpret_cast<Byte *>(key), key_len, enc_mode);
  if (encrypt_data == nullptr) {
    MS_EXCEPTION(ValueError) << "Encrypt failed";
  }
  auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
  return py_encrypt_data;
}

py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode) {
  size_t decrypt_len;
  auto decrypt_data =
    mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
  if (decrypt_data == nullptr) {
    MS_LOG(ERROR) << "Decrypt failed";
    return py::none();
  }
  auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
  return py_decrypt_data;
}
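
// Round-trip sketch (illustrative; the Python binding names and the "AES-GCM"
// mode string are assumptions): PyEncrypt returns ciphertext bytes while
// PyDecrypt reads a ciphertext *file*, so a round trip looks like:
//   enc = py_encrypt(plain, len(plain), key, len(key), "AES-GCM")
//   open("model.enc", "wb").write(enc)
//   dec = py_decrypt("model.enc", key, len(key), "AES-GCM")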

bool PyIsCipherFile(const std::string &file_path) { return mindspore::IsCipherFile(file_path); }
}  // namespace pipeline
}  // namespace mindspore