1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/session/kernel_graph_mgr.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <utility>
21 #include <functional>
22 #include <unordered_map>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "include/common/debug/anf_ir_dump.h"
26 #include "runtime/device/kernel_runtime_manager.h"
27 #include "backend/common/optimizer/common_backend_optimization.h"
28 #include "pipeline/pynative/grad/jit/jit_call_graph.h"
29 #include "utils/trace_base.h"
30 #include "ir/func_graph_cloner.h"
31 #ifndef ENABLE_SECURITY
32 #include "include/backend/debug/data_dump/dump_json_parser.h"
33 #include "include/backend/debug/data_dump/e2e_dump.h"
34 #endif
35 #include "include/common/utils/compile_cache_context.h"
36 #include "include/common/utils/config_manager.h"
37 #include "load_mindir/load_model.h"
38 #include "include/common/debug/dump_proto.h"
39
40 namespace mindspore {
41 namespace session {
42 namespace {
// Budget for one user search: RecursiveCheck stops expanding node users once its shared counter exceeds this value.
constexpr size_t kMaxDepth = 128;
constexpr size_t kFirstIndex = 1;
// RecursiveCheck does not follow a Depend/Load user when the node is consumed at an input position above this index.
constexpr int64_t kPairIdx1 = 1;
// The maximum value of uint32_t is 4294967295.
// Graph ids in graph mode start from 0.
// Graph ids in pynative mode start from 4000000000.
constexpr uint32_t kPynativeGraphIdStart = 4000000000;
50
IsGeReturnNode(const AnfNodePtr & node)51 bool IsGeReturnNode(const AnfNodePtr &node) {
52 auto context = MsContext::GetInstance();
53 MS_EXCEPTION_IF_NULL(context);
54 const bool enable_ge = context->backend_policy() == "ge";
55 if (!enable_ge) {
56 return false;
57 }
58 MS_EXCEPTION_IF_NULL(node);
59 auto cnode = node->cast<CNodePtr>();
60 if (cnode == nullptr) {
61 // parameter and value node is a real kernel too
62 return true;
63 }
64 if (cnode->size() == 0) {
65 MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString()
66 << trace::DumpSourceLines(node);
67 }
68 return IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), {prim::kPrimReturn});
69 }
70
RecursiveCheck(const FuncGraphManagerPtr & manager,const std::pair<AnfNodePtr,int64_t> & kernel,size_t * idx)71 bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodePtr, int64_t> &kernel, size_t *idx) {
72 auto node = kernel.first;
73 MS_EXCEPTION_IF_NULL(manager);
74 MS_EXCEPTION_IF_NULL(node);
75 if (kernel.second > kPairIdx1 && (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) ||
76 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
77 return false;
78 }
79 if ((AnfUtils::IsRealKernel(node) || IsGeReturnNode(node) || AnfAlgo::IsSummaryNode(node)) &&
80 !common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
81 return true;
82 }
83 (*idx) += 1;
84 // max recursion depth
85 if (*idx <= kMaxDepth) {
86 auto users = manager->node_users()[node];
87 if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
88 return RecursiveCheck(manager, kernel, idx);
89 })) {
90 return true;
91 }
92 }
93 return false;
94 }
95
IsUsedByRealKernel(const FuncGraphManagerPtr & manager,const AnfNodePtr & node,const uint32_t graph_id)96 bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, const uint32_t graph_id) {
97 MS_EXCEPTION_IF_NULL(manager);
98 MS_EXCEPTION_IF_NULL(node);
99 auto node_users = manager->node_users()[node];
100 // filter nodes not in current graph
101 for (auto iter = node_users.begin(); iter != node_users.end();) {
102 auto func_graph = iter->first->func_graph();
103 MS_EXCEPTION_IF_NULL(func_graph);
104 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
105 if (kernel_graph == nullptr) {
106 MS_LOG(EXCEPTION) << "func graph cast kernel graph failed, related node is: " << iter->first->DebugString();
107 }
108 if (kernel_graph->graph_id() != graph_id) {
109 iter = node_users.erase(iter);
110 } else {
111 iter++;
112 }
113 }
114
115 size_t idx = 0;
116 if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
117 return RecursiveCheck(manager, kernel, &idx);
118 })) {
119 return true;
120 }
121 return false;
122 }
123
ExistGraphCaller(const AnfNodePtr & partial_node)124 bool ExistGraphCaller(const AnfNodePtr &partial_node) {
125 MS_EXCEPTION_IF_NULL(partial_node);
126 auto partial_cnode = partial_node->cast<CNodePtr>();
127 MS_EXCEPTION_IF_NULL(partial_cnode);
128 auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
129 // If graph is nullptr, it means that the funcgraph in the partial node is a deadnode, and the processing is skipped.
130 if (partial_graph == nullptr) {
131 return false;
132 }
133 auto graph_nodes = TopoSort(partial_graph->get_return());
134 return std::any_of(graph_nodes.begin(), graph_nodes.end(), IsValueNode<FuncGraph>);
135 }
136
CheckPath(const std::optional<std::string> & path)137 bool CheckPath(const std::optional<std::string> &path) {
138 if (!path.has_value()) {
139 return false;
140 }
141 std::ifstream f(path.value());
142 bool file_is_good = f.good();
143 f.close();
144 if (!file_is_good) {
145 MS_LOG(WARNING) << "Open the compilation cache file " << path.value() << " failed.";
146 return false;
147 }
148 return true;
149 }
150
LoadJson(const std::string & filename,nlohmann::json * graph_json)151 bool LoadJson(const std::string &filename, nlohmann::json *graph_json) {
152 std::ifstream json_fs(filename);
153 if (!json_fs.is_open()) {
154 MS_LOG(ERROR) << "Open json file: " << filename << " error, backend graph cache Missed.";
155 return false;
156 }
157 try {
158 json_fs >> *graph_json;
159 json_fs.close();
160 } catch (std::exception &e) {
161 MS_LOG(INFO) << "Parse json file error: " << filename << ", sleep 500ms and retry again.";
162 json_fs.close();
163 std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalMilliSeconds));
164 std::ifstream retry_tmp(filename);
165 if (!retry_tmp.is_open()) {
166 MS_LOG(ERROR) << "Open json file: " << filename << " error, please check cached file.";
167 return false;
168 }
169 retry_tmp >> *graph_json;
170 retry_tmp.close();
171 }
172 return true;
173 }
174
// Converts the leading text of `str` to a value of `Type` using stream extraction.
// The result is value-initialized first, so an input from which nothing can be extracted
// (for example an empty string) yields Type{} instead of an indeterminate value.
template <typename Type>
Type StringToNum(const std::string &str) {
  std::istringstream iss(str);
  Type num{};
  iss >> num;
  return num;
}
182
LoadKernelInfoRuntimeCache(const nlohmann::json & kernel_info_value,std::shared_ptr<KernelInfoDevice> kernel_info)183 void LoadKernelInfoRuntimeCache(const nlohmann::json &kernel_info_value,
184 std::shared_ptr<KernelInfoDevice> kernel_info) {
185 if (!kernel_info_value.contains(kRuntimeCacheValid)) {
186 return;
187 }
188 auto &context = CompileCacheContext::GetInstance();
189 auto &rt = kernel_info->runtime_cache().runtime_cache();
190 rt.set_is_valid(kernel_info_value[kRuntimeCacheValid]);
191 rt.set_device_target(kernel_info_value[kRuntimeCacheDeviceTarget]);
192 rt.set_output_tensor_num(kernel_info_value[kRuntimeCacheOutputTensorNum]);
193 rt.set_real_kernel(kernel_info_value[kRuntimeCacheIsRealKernel]);
194 if (kernel_info_value.contains(kRuntimeCachePrevOutputs)) {
195 const auto &prev_outputs = kernel_info_value[kRuntimeCachePrevOutputs];
196 for (const auto &prev_output : prev_outputs) {
197 const auto &first_index = prev_output.at(0);
198 const auto &name = prev_output.at(kIndexOne);
199 const auto &second_index = prev_output.at(kIndexTwo);
200 auto output_node = context.FindBackNodeByBackName(name);
201 MS_EXCEPTION_IF_NULL(output_node);
202 rt.update_prev_node_output(first_index, std::make_pair(output_node, second_index));
203 }
204 }
205 }
206
LoadAnfSelectKernelBuildInfo(const nlohmann::json & kernel_info_value,const AnfNodePtr & node)207 void LoadAnfSelectKernelBuildInfo(const nlohmann::json &kernel_info_value, const AnfNodePtr &node) {
208 if (!kernel_info_value.contains(kHasSelectKernelBuildInfo)) {
209 return;
210 }
211 auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
212 MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
213
214 auto kernel_type = kernel_info_value[kKernelType];
215 kernel_build_info_builder->SetKernelType(kernel_type);
216
217 auto op_type = kernel_info_value[kOpType];
218 kernel_build_info_builder->SetOpType(op_type);
219 auto core_type = kernel_info_value[kCoreType];
220 kernel_build_info_builder->SetCoreType(core_type);
221
222 if (kernel_info_value.contains(kOriginDataFormat)) {
223 auto origin_data_format = kernel_info_value[kOriginDataFormat];
224 kernel_build_info_builder->SetOriginDataFormat(origin_data_format);
225 }
226
227 if (kernel_info_value.contains(kAllInputFormat)) {
228 auto all_input_format = kernel_info_value[kAllInputFormat];
229 kernel_build_info_builder->SetInputsFormat(all_input_format);
230 }
231
232 auto pattern = kernel_info_value[kOpPattern];
233 kernel_build_info_builder->SetOpPattern(pattern);
234 if (kernel_info_value.contains(kAllOutputFormat)) {
235 auto all_output_format = kernel_info_value[kAllOutputFormat];
236 kernel_build_info_builder->SetOutputsFormat(all_output_format);
237 }
238 if (kernel_info_value.contains(kAllInputReshapeType)) {
239 auto all_input_reshape_type = kernel_info_value[kAllInputReshapeType];
240 kernel_build_info_builder->SetInputsReshapeType(all_input_reshape_type);
241 }
242
243 if (kernel_info_value.contains(kAllOutputReshapeType)) {
244 auto all_output_reshape_type = kernel_info_value[kAllOutputReshapeType];
245 kernel_build_info_builder->SetOutputsReshapeType(all_output_reshape_type);
246 }
247 if (kernel_info_value.contains(kAllInputDeviceType)) {
248 auto all_input_device_type = kernel_info_value[kAllInputDeviceType];
249 kernel_build_info_builder->SetInputsDeviceType(all_input_device_type);
250 }
251 if (kernel_info_value.contains(kAllOutputDeviceType)) {
252 auto all_output_device_type = kernel_info_value[kAllOutputDeviceType];
253 kernel_build_info_builder->SetOutputsDeviceType(all_output_device_type);
254 }
255 if (kernel_info_value.contains(kInputKernelObjectTypes)) {
256 auto input_kernel_object_types = kernel_info_value[kInputKernelObjectTypes];
257 kernel_build_info_builder->SetInputsKernelObjectType(input_kernel_object_types);
258 }
259 if (kernel_info_value.contains(kOutputKernelObjectTypes)) {
260 auto output_kernel_object_types = kernel_info_value[kOutputKernelObjectTypes];
261 kernel_build_info_builder->SetOutputsKernelObjectType(output_kernel_object_types);
262 }
263 if (kernel_info_value.contains(kOutputElementsKernelObjectTypes)) {
264 auto output_elements_kernel_object_types = kernel_info_value[kOutputElementsKernelObjectTypes];
265 kernel_build_info_builder->SetOutputElementsKernelObjectType(output_elements_kernel_object_types);
266 }
267
268 if (kernel_info_value.contains(kOutputDataDesc)) {
269 auto output_data_desc = kernel_info_value[kOutputDataDesc];
270 kernel_build_info_builder->SetOutputDataDesc(output_data_desc);
271 }
272 auto fusion_type = kernel_info_value[kFusionType];
273 kernel_build_info_builder->SetFusionType(fusion_type);
274 auto processor = kernel_info_value[kProcessor];
275 kernel_build_info_builder->SetProcessor(processor);
276 auto valid = kernel_info_value[kKernelBuildInfoValid];
277 kernel_build_info_builder->SetValid(valid);
278 const auto &kernel_build = kernel_build_info_builder->Build();
279 AnfAlgo::SetSelectKernelBuildInfo(kernel_build, node.get());
280 }
281
LoadAnfKernelInfoFromJson(const nlohmann::json & kernel_infos_json)282 void LoadAnfKernelInfoFromJson(const nlohmann::json &kernel_infos_json) {
283 auto &context = CompileCacheContext::GetInstance();
284 for (const auto &[name, kernel_info_value] : kernel_infos_json.items()) {
285 auto node = context.FindBackNodeByBackName(name);
286 MS_EXCEPTION_IF_NULL(node);
287 MS_LOG(DEBUG) << "Load node " << node->DebugString() << " kernel info from json.";
288 auto kernel_info = std::make_shared<device::KernelInfo>();
289 MS_EXCEPTION_IF_NULL(kernel_info);
290 if (kernel_info_value.contains(kOutInRef)) {
291 const auto &out_in_ref_json = kernel_info_value[kOutInRef];
292 for (const auto &[out, in] : out_in_ref_json.items()) {
293 kernel_info->AddRefMap(StringToNum<size_t>(out), in);
294 }
295 }
296 kernel_info->set_graph_id(kernel_info_value[kGraphId]);
297 kernel_info->set_feature_map_flag(kernel_info_value[kIsFeatureMap]);
298 node->set_kernel_info(kernel_info);
299 LoadAnfSelectKernelBuildInfo(kernel_info_value, node);
300
301 if (node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, node->cast<CNodePtr>()) &&
302 common::AnfAlgo::GetNodeAttr<bool>(node->cast<CNodePtr>(), kAttrIsUBFusionOp)) {
303 if (kernel_info_value.contains(kJsonName) && kernel_info_value.contains(kInputSizeList) &&
304 kernel_info_value.contains(kOutputSizeList)) {
305 CachedIOSizeInfo io_size;
306 io_size.json_name = kernel_info_value[kJsonName];
307 const auto &input_size_list = kernel_info_value[kInputSizeList];
308 const auto &output_size_list = kernel_info_value[kOutputSizeList];
309 (void)(io_size.input_size_list.insert(io_size.input_size_list.end(), input_size_list.begin(),
310 input_size_list.end()));
311 (void)(io_size.output_size_list.insert(io_size.output_size_list.end(), output_size_list.begin(),
312 output_size_list.end()));
313 context.PushFullnameIoSizeInfo(node->fullname_with_scope(), io_size);
314 } else {
315 MS_LOG(EXCEPTION) << "Load node " << node->DebugString() << " kernel_io_size_info failed.";
316 }
317 }
318 LoadKernelInfoRuntimeCache(kernel_info_value, kernel_info);
319 }
320 }
321
GetAnfUniqueCacheName(const AnfNodePtr & node,bool must_have_unique_name=true)322 std::string GetAnfUniqueCacheName(const AnfNodePtr &node, bool must_have_unique_name = true) {
323 MS_EXCEPTION_IF_NULL(node);
324 const auto &name = node->user_data<std::string>(kUniqueCacheName);
325 if (must_have_unique_name && name == nullptr) {
326 MS_LOG(EXCEPTION) << "The node " << node->DebugString()
327 << " has not unique name, indicating that it is not exported to mindir.";
328 }
329 return name != nullptr ? *name : "";
330 }
331
SaveAnfToAnfMap(const HashMap<AnfNodePtr,AnfNodePtr> & save_map)332 nlohmann::json SaveAnfToAnfMap(const HashMap<AnfNodePtr, AnfNodePtr> &save_map) {
333 nlohmann::json ret;
334 for (const auto &i : save_map) {
335 const auto &first_name = GetAnfUniqueCacheName(i.first, false);
336 const auto &second_name = GetAnfUniqueCacheName(i.second, false);
337 // allow some node not to be exported to mindir.
338 if (first_name.empty() || second_name.empty()) {
339 continue;
340 }
341 ret[first_name] = second_name;
342 }
343 return ret;
344 }
345
SaveAnfToAnfIndexMap(const HashMap<AnfNodePtr,AnfWithOutIndex> & save_map)346 std::vector<nlohmann::json> SaveAnfToAnfIndexMap(const HashMap<AnfNodePtr, AnfWithOutIndex> &save_map) {
347 std::vector<nlohmann::json> ret_json;
348 for (const auto &i : save_map) {
349 nlohmann::json iter_json;
350 const auto &first_name = GetAnfUniqueCacheName(i.first, false);
351 const auto &second_name = GetAnfUniqueCacheName(i.second.first, false);
352 // allow some node not to be exported to mindir.
353 if (first_name.empty() || second_name.empty()) {
354 continue;
355 }
356 iter_json.push_back(first_name);
357 iter_json.push_back(second_name);
358 iter_json.push_back(i.second.second);
359 (void)(ret_json.emplace_back(iter_json));
360 }
361 return ret_json;
362 }
363
SaveAnfIndexToAnfIndexMap(const std::map<AnfWithOutIndex,AnfWithOutIndex> & save_map)364 std::vector<nlohmann::json> SaveAnfIndexToAnfIndexMap(const std::map<AnfWithOutIndex, AnfWithOutIndex> &save_map) {
365 std::vector<nlohmann::json> ret_json;
366 for (const auto &i : save_map) {
367 nlohmann::json iter_json;
368 const auto &first_name = GetAnfUniqueCacheName(i.first.first, false);
369 const auto &second_name = GetAnfUniqueCacheName(i.second.first, false);
370 // allow some node not to be exported to mindir.
371 if (first_name.empty() || second_name.empty()) {
372 continue;
373 }
374 iter_json.push_back(first_name);
375 iter_json.push_back(i.first.second);
376 iter_json.push_back(second_name);
377 iter_json.push_back(i.second.second);
378 (void)(ret_json.emplace_back(iter_json));
379 }
380 return ret_json;
381 }
382
SaveValueSet(const HashSet<ValueNodePtr> & save_anfs)383 nlohmann::json SaveValueSet(const HashSet<ValueNodePtr> &save_anfs) {
384 nlohmann::json iter_json;
385 for (const auto &i : save_anfs) {
386 const auto &name = GetAnfUniqueCacheName(i, false);
387 // allow some value node not to be exported to mindir.
388 if (name.empty()) {
389 continue;
390 }
391 (void)(iter_json.emplace_back(name));
392 }
393 return iter_json;
394 }
395
396 template <typename T>
SaveAnfVec(const std::vector<T> & save_anfs)397 nlohmann::json SaveAnfVec(const std::vector<T> &save_anfs) {
398 nlohmann::json ret_json;
399 for (const auto &i : save_anfs) {
400 const auto &name = GetAnfUniqueCacheName(i);
401 (void)(ret_json.emplace_back(name));
402 }
403 return ret_json;
404 }
405
SaveGraphVec(const std::vector<std::weak_ptr<KernelGraph>> & save_anfs)406 nlohmann::json SaveGraphVec(const std::vector<std::weak_ptr<KernelGraph>> &save_anfs) {
407 nlohmann::json ret_json;
408 for (const auto &i : save_anfs) {
409 const auto &ptr = i.lock();
410 MS_EXCEPTION_IF_NULL(ptr);
411 (void)(ret_json.emplace_back(ptr->graph_id()));
412 }
413 return ret_json;
414 }
415
SaveGraphsId(const HashMap<uint32_t,std::weak_ptr<session::KernelGraph>> & to_save)416 nlohmann::json SaveGraphsId(const HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> &to_save) {
417 nlohmann::json ret_json;
418 for (const auto &i : to_save) {
419 nlohmann::json iter_json;
420 const auto &ptr = i.second.lock();
421 MS_EXCEPTION_IF_NULL(ptr);
422 ret_json.push_back(ptr->graph_id());
423 }
424 return ret_json;
425 }
426
SavePrevOutputs(const std::map<size_t,std::pair<AnfNodeWeakPtr,size_t>> & save_map)427 std::vector<nlohmann::json> SavePrevOutputs(const std::map<size_t, std::pair<AnfNodeWeakPtr, size_t>> &save_map) {
428 std::vector<nlohmann::json> ret_json;
429 for (const auto &i : save_map) {
430 nlohmann::json iter_json;
431 const auto &node = i.second.first.lock();
432 MS_EXCEPTION_IF_NULL(node);
433 const auto &name = GetAnfUniqueCacheName(node, false);
434 if (name.empty()) {
435 continue;
436 }
437 iter_json.push_back(i.first);
438 iter_json.push_back(name);
439 iter_json.push_back(i.second.second);
440 ret_json.push_back(iter_json);
441 }
442 return ret_json;
443 }
444
SaveKernelInfoRuntimeCache(KernelInfoDevice * kernel_info,nlohmann::json * const single_json)445 void SaveKernelInfoRuntimeCache(KernelInfoDevice *kernel_info, nlohmann::json *const single_json) {
446 MS_EXCEPTION_IF_NULL(kernel_info);
447 MS_EXCEPTION_IF_NULL(single_json);
448 const auto &rt = kernel_info->runtime_cache().runtime_cache();
449 if (!rt.is_valid()) {
450 return;
451 }
452 (*single_json)[kRuntimeCacheValid] = rt.is_valid();
453 (*single_json)[kRuntimeCacheDeviceTarget] = rt.device_target();
454 (*single_json)[kRuntimeCacheOutputTensorNum] = rt.output_tensor_num();
455 (*single_json)[kRuntimeCacheIsRealKernel] = rt.is_real_kernel();
456 const auto &prev_outputs_json = SavePrevOutputs(rt.GetPrevOutputs());
457 if (!prev_outputs_json.empty()) {
458 (*single_json)[kRuntimeCachePrevOutputs] = prev_outputs_json;
459 }
460 }
461
SaveAnfKernelInfo(const AnfNodePtr & node)462 nlohmann::json SaveAnfKernelInfo(const AnfNodePtr &node) {
463 nlohmann::json single_json;
464 if (AnfUtils::IsRealKernel(node)) {
465 single_json[kOriginDataFormat] = AnfAlgo::GetOriginDataFormat(node);
466 const auto &input_formats = AnfAlgo::GetAllInputFormats(node);
467 if (!input_formats.empty()) {
468 single_json[kAllInputFormat] = input_formats;
469 }
470 const auto &output_formats = AnfAlgo::GetAllOutputFormats(node);
471 if (!output_formats.empty()) {
472 single_json[kAllOutputFormat] = output_formats;
473 }
474 const auto &input_device_types = AnfAlgo::GetAllInputDeviceTypes(node);
475 if (!input_device_types.empty()) {
476 single_json[kAllInputDeviceType] = input_device_types;
477 }
478 const auto &output_device_types = AnfAlgo::GetAllOutputDeviceTypes(node);
479 if (!output_device_types.empty()) {
480 single_json[kAllOutputDeviceType] = output_device_types;
481 }
482 }
483 if (AnfAlgo::HasSelectKernelBuildInfo(node)) {
484 auto kernel_type = AnfAlgo::GetKernelType(node);
485 single_json[kKernelType] = kernel_type;
486 auto op_type = AnfAlgo::GetOpType(node);
487 single_json[kOpType] = op_type;
488 single_json[kCoreType] = AnfAlgo::GetCoreType(node);
489 single_json[kOpPattern] = AnfAlgo::GetOpPattern(node);
490 const auto &input_reshape_types_json = AnfAlgo::GetAllInputReshapeType(node);
491 if (!input_reshape_types_json.empty()) {
492 single_json[kAllInputReshapeType] = input_reshape_types_json;
493 }
494 const auto &output_reshape_types_json = AnfAlgo::GetAllOutputReshapeType(node);
495 if (!output_reshape_types_json.empty()) {
496 single_json[kAllOutputReshapeType] = output_reshape_types_json;
497 }
498 const auto &input_kernel_object_types_json = AnfAlgo::GetInputKernelObjectTypes(node);
499 if (!input_kernel_object_types_json.empty()) {
500 single_json[kInputKernelObjectTypes] = input_kernel_object_types_json;
501 }
502 const auto &output_kernel_object_types_json = AnfAlgo::GetOutputKernelObjectTypes(node);
503 if (!output_kernel_object_types_json.empty()) {
504 single_json[kOutputKernelObjectTypes] = output_kernel_object_types_json;
505 }
506 const auto &output_elements_kernel_object_types_json = AnfAlgo::GetOutputElementsKernelObjectTypes(node);
507 if (!output_elements_kernel_object_types_json.empty()) {
508 single_json[kOutputElementsKernelObjectTypes] = output_elements_kernel_object_types_json;
509 }
510
511 const auto &output_desc_json = AnfAlgo::GetOutputDataDesc(node);
512 if (!output_desc_json.empty()) {
513 single_json[kOutputDataDesc] = output_desc_json;
514 }
515 single_json[kFusionType] = AnfAlgo::GetFusionType(node);
516 single_json[kProcessor] = AnfAlgo::GetProcessor(node);
517 single_json[kKernelBuildInfoValid] = AnfAlgo::GetValid(node);
518 single_json[kHasSelectKernelBuildInfo] = true;
519 }
520 const auto &kernel_info = node->kernel_info();
521 MS_EXCEPTION_IF_NULL(kernel_info);
522 const auto &device_kernel_info = dynamic_cast<device::KernelInfo *>(kernel_info);
523 MS_EXCEPTION_IF_NULL(device_kernel_info);
524 nlohmann::json out_in_ref_json;
525 const auto &out_in_ref = device_kernel_info->out_in_ref_map();
526 (void)(std::for_each(out_in_ref.begin(), out_in_ref.end(),
527 [&out_in_ref_json](const auto &iter) { out_in_ref_json[iter.first] = iter.second; }));
528 if (!out_in_ref_json.empty()) {
529 single_json[kOutInRef] = out_in_ref_json;
530 }
531 const auto &graph_id = device_kernel_info->graph_id();
532 single_json[kGraphId] = graph_id;
533 const auto &is_feature_map = device_kernel_info->is_feature_map();
534 single_json[kIsFeatureMap] = is_feature_map;
535
536 if (node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, node->cast<CNodePtr>()) &&
537 common::AnfAlgo::GetNodeAttr<bool>(node->cast<CNodePtr>(), kAttrIsUBFusionOp)) {
538 auto &context = CompileCacheContext::GetInstance();
539 const auto &io_size = context.GetIOSizeInfo(node->fullname_with_scope());
540 single_json[kJsonName] = io_size.json_name;
541 const auto input_size_list_json = io_size.input_size_list;
542 if (!input_size_list_json.empty()) {
543 single_json[kInputSizeList] = input_size_list_json;
544 }
545 const auto output_size_list_json = io_size.output_size_list;
546 if (!output_size_list_json.empty()) {
547 single_json[kOutputSizeList] = output_size_list_json;
548 }
549 }
550 SaveKernelInfoRuntimeCache(kernel_info, &single_json);
551 return single_json;
552 }
553
SaveBackendParamToFrontendParamIndex(const KernelGraphPtr & kernel_graph,const FuncGraph * front_graph)554 nlohmann::json SaveBackendParamToFrontendParamIndex(const KernelGraphPtr &kernel_graph, const FuncGraph *front_graph) {
555 nlohmann::json ret;
556 const auto ¶ms = kernel_graph->parameters();
557 auto &context = CompileCacheContext::GetInstance();
558 const auto &front_params = front_graph->parameters();
559 for (const auto ¶m : params) {
560 if (!context.IsBackendParamGenFromFrontendParam(param)) {
561 continue;
562 }
563 const auto &name = param->user_data<std::string>(kUniqueCacheName);
564 MS_EXCEPTION_IF_NULL(name);
565 const auto &front_param = kernel_graph->GetFrontAnfByBackendAnf(param);
566 MS_EXCEPTION_IF_NULL(front_param);
567 auto iter = std::find(front_params.begin(), front_params.end(), front_param);
568 if (iter == front_params.end()) {
569 MS_LOG(EXCEPTION) << "Backend param " << param->DebugString() << " correspond frontend param "
570 << front_param->DebugString() << " can not find in frontend graph params.";
571 }
572 ret[*name] = std::distance(front_params.begin(), iter);
573 }
574 return ret;
575 }
576
SaveNodesKernelInfoAndParamsName(const KernelGraphPtr & kg,const std::vector<AnfNodePtr> & isolated_nodes,nlohmann::json * const kg_json)577 void SaveNodesKernelInfoAndParamsName(const KernelGraphPtr &kg, const std::vector<AnfNodePtr> &isolated_nodes,
578 nlohmann::json *const kg_json) {
579 std::vector<AnfNodePtr> nodes = TopoSort(kg->get_return(), SuccIncoming, AlwaysInclude);
580 nlohmann::json kernels_info_json;
581 (void)(nodes.insert(nodes.end(), isolated_nodes.begin(), isolated_nodes.end()));
582 const auto ¶ms = kg->parameters();
583 std::vector<AnfNodePtr> isolated_params;
584 (void)(std::set_difference(params.begin(), params.end(), nodes.begin(), nodes.end(),
585 std::back_inserter(isolated_params)));
586 (void)(nodes.insert(nodes.end(), isolated_params.begin(), isolated_params.end()));
587 nlohmann::json param_unique_name_to_name;
588 for (const auto &node : nodes) {
589 MS_EXCEPTION_IF_NULL(node);
590 if (node->kernel_info() == nullptr && !node->isa<Parameter>()) {
591 continue;
592 }
593 const auto &name = GetAnfUniqueCacheName(node);
594 if (node->isa<Parameter>()) {
595 auto param = node->cast<ParameterPtr>();
596 param_unique_name_to_name[name] = param->name();
597 }
598 if (node->kernel_info() == nullptr) {
599 MS_LOG(WARNING) << "The node " << node->DebugString() << " has not kernel_info.";
600 continue;
601 }
602 const auto &kernel_info_json = SaveAnfKernelInfo(node);
603 if (!kernel_info_json.empty()) {
604 kernels_info_json[name] = kernel_info_json;
605 }
606 }
607 (*kg_json)[kParameterUniqueNameToName] = param_unique_name_to_name;
608 (*kg_json)[kNodesKernelInfo] = kernels_info_json;
609 }
610
SaveSummaryNodes(const std::map<std::string,std::pair<AnfNodePtr,int>> & save_map)611 std::vector<nlohmann::json> SaveSummaryNodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &save_map) {
612 std::vector<nlohmann::json> ret_json;
613 for (const auto &i : save_map) {
614 nlohmann::json iter_json;
615 const auto &first = i.first;
616 const auto &node = i.second.first;
617 const auto &name = GetAnfUniqueCacheName(node);
618 const auto &index = i.second.second;
619 iter_json.push_back(first);
620 iter_json.push_back(name);
621 iter_json.push_back(index);
622 ret_json.push_back(iter_json);
623 }
624 return ret_json;
625 }
626
GenKernelGraphJson(const KernelGraphPtr & kg,const std::vector<AnfNodePtr> & isolated_nodes)627 nlohmann::json GenKernelGraphJson(const KernelGraphPtr &kg, const std::vector<AnfNodePtr> &isolated_nodes) {
628 nlohmann::json kg_json;
629 SaveNodesKernelInfoAndParamsName(kg, isolated_nodes, &kg_json);
630 kg_json[kGraphId] = kg->graph_id();
631 kg_json[kRunMode] = kg->RunMode();
632 kg_json[kIsLoopCountSink] = kg->is_loop_count_sink();
633 kg_json[kIsDynamicShape] = kg->is_dynamic_shape();
634 kg_json[kDeviceTarget] = kg->device_target();
635 kg_json[kRootGraphId] = kg->root_graph_id();
636 kg_json[kExecutable] = kg->executable();
637 kg_json[kHasRecursiveCall] = kg->recursive_call();
638 kg_json[kHasSubgraphMultiCall] = kg->subgraph_multi_call();
639 kg_json[kNeedInline] = kg->need_inline();
640 kg_json[kIsNeedGil] = kg->is_need_gil();
641 kg_json[kIsFromSingleOp] = kg->is_from_single_op();
642 kg_json[kLabelNum] = kg->label_num();
643 #ifndef ENABLE_SECURITY
644 kg_json[kSummaryNodeExist] = kg->summary_node_exist();
645 #endif
646 const auto &back_front_anf_json = SaveAnfToAnfMap(kg->backend_front_anf_map());
647 if (!back_front_anf_json.empty()) {
648 kg_json[kBackendFrontAnf] = back_front_anf_json;
649 }
650 const auto &internal_params_to_front_node_json = SaveAnfToAnfIndexMap(kg->InternalParameterToFrontNodeMap());
651 if (!internal_params_to_front_node_json.empty()) {
652 kg_json[kInternalParameterToFrontNode] = internal_params_to_front_node_json;
653 }
654 const auto &ref_in_out_map_json = SaveAnfIndexToAnfIndexMap(kg->GetRefMap());
655 if (!ref_in_out_map_json.empty()) {
656 kg_json[kRefInOutMap] = ref_in_out_map_json;
657 }
658 const auto &graph_value_nodes = SaveValueSet(kg->graph_value_nodes());
659 if (!graph_value_nodes.empty()) {
660 kg_json[kGraphValueNodes] = graph_value_nodes;
661 }
662 const auto &exec_order_json = SaveAnfVec(kg->execution_order());
663 if (!exec_order_json.empty()) {
664 kg_json[kExecutionOrder] = exec_order_json;
665 }
666 const auto &inputs_json = SaveAnfVec(kg->inputs());
667 if (!inputs_json.empty()) {
668 kg_json[kInputs] = inputs_json;
669 }
670 const auto ¶meters_json = SaveAnfVec(kg->parameters());
671 if (!parameters_json.empty()) {
672 kg_json[kParameters] = parameters_json;
673 }
674 const auto &child_graph_result_json = SaveAnfVec(kg->child_graph_result());
675 if (!child_graph_result_json.empty()) {
676 kg_json[kChildGraphResult] = child_graph_result_json;
677 }
678 const auto &child_graph_order_json = SaveGraphVec(kg->child_graph_order());
679 if (!child_graph_order_json.empty()) {
680 kg_json[kChildGraphOrder] = child_graph_order_json;
681 }
682 const auto &start = kg->get_start_label();
683 if (start) {
684 kg_json[kStartLabel] = GetAnfUniqueCacheName(start);
685 }
686 const auto &end = kg->get_end_goto();
687 if (end) {
688 kg_json[kEndGoto] = GetAnfUniqueCacheName(end);
689 }
690 const auto &valid_inputs = kg->valid_inputs();
691 if (!valid_inputs.empty()) {
692 kg_json[kValidInputs] = valid_inputs;
693 }
694 const auto &pre_graphs_json = SaveGraphsId(kg->get_pre_graphs());
695 if (!pre_graphs_json.empty()) {
696 kg_json[kPreGraphs] = pre_graphs_json;
697 }
698 const auto &post_graphs_json = SaveGraphsId(kg->GetPostGraphs());
699 if (!post_graphs_json.empty()) {
700 kg_json[kPostGraphs] = post_graphs_json;
701 }
702 const auto &index_set = kg->CommSubGraphIds();
703 if (!index_set.empty()) {
704 kg_json[kCommSubGraphIds] = index_set;
705 }
706 #ifndef ENABLE_SECURITY
707 const auto &summary_nodes_json = SaveSummaryNodes(kg->summary_nodes());
708 if (!summary_nodes_json.empty()) {
709 kg_json[kSummaryNodes] = summary_nodes_json;
710 }
711 #endif
712 auto &context = CompileCacheContext::GetInstance();
713 auto front_graph = context.GetFrontendGraphByBackendGraph(kg);
714 if (front_graph) {
715 kg_json[kCorrespondFrontendGraph] = front_graph->ToString();
716 }
717 kg_json[kBackendParamToFrontendParamIndex] = SaveBackendParamToFrontendParamIndex(kg, front_graph);
718 return kg_json;
719 }
720
DumpKernelGraphJson(const KernelGraphPtr & root_graph,const std::set<KernelGraphPtr> & child_graphs,const std::map<KernelGraphPtr,std::vector<AnfNodePtr>> & isolated_nodes_map,const std::string & path)721 bool DumpKernelGraphJson(const KernelGraphPtr &root_graph, const std::set<KernelGraphPtr> &child_graphs,
722 const std::map<KernelGraphPtr, std::vector<AnfNodePtr>> &isolated_nodes_map,
723 const std::string &path) {
724 nlohmann::json kg_json;
725 kg_json[root_graph->ToString()] = GenKernelGraphJson(root_graph, isolated_nodes_map.find(root_graph)->second);
726 for (const auto &graph : child_graphs) {
727 kg_json[graph->ToString()] = GenKernelGraphJson(graph, isolated_nodes_map.find(graph)->second);
728 }
729 return Common::SaveStringToFile(path, kg_json.dump());
730 }
731
GetAllChildGraph(const KernelGraphPtr & kg,std::set<KernelGraphPtr> * visit,std::set<KernelGraphPtr> * graphs)732 void GetAllChildGraph(const KernelGraphPtr &kg, std::set<KernelGraphPtr> *visit, std::set<KernelGraphPtr> *graphs) {
733 if (kg == nullptr || kg->IsLeafGraph()) {
734 return;
735 }
736 MS_EXCEPTION_IF_NULL(visit);
737 MS_EXCEPTION_IF_NULL(graphs);
738 if (visit->find(kg) != visit->end()) {
739 return;
740 }
741 const auto &order = kg->child_graph_order();
742 for (auto iter : order) {
743 auto graph = iter.lock();
744 MS_EXCEPTION_IF_NULL(graph);
745 (void)(graphs->insert(graph));
746 }
747 (void)(visit->insert(kg));
748
749 for (auto &i : order) {
750 GetAllChildGraph(i.lock(), visit, graphs);
751 }
752 }
753
GetIsolatedNodes(const KernelGraphPtr & kg,std::vector<AnfNodePtr> * isolated_nodes)754 void GetIsolatedNodes(const KernelGraphPtr &kg, std::vector<AnfNodePtr> *isolated_nodes) {
755 MS_EXCEPTION_IF_NULL(kg);
756 const auto &orders = kg->execution_order();
757 std::vector<AnfNodePtr> possible_isolated(orders.begin(), orders.end());
758 const auto &start = kg->get_start_label();
759 if (start && std::find(possible_isolated.begin(), possible_isolated.end(), start) == possible_isolated.end()) {
760 possible_isolated.push_back(start);
761 }
762 const auto &end = kg->get_end_goto();
763 if (end && std::find(possible_isolated.begin(), possible_isolated.end(), end) == possible_isolated.end()) {
764 possible_isolated.push_back(end);
765 }
766 auto topo_nodes = TopoSort(kg->get_return(), SuccIncoming, AlwaysInclude);
767 (void)(std::set_difference(possible_isolated.begin(), possible_isolated.end(), topo_nodes.begin(), topo_nodes.end(),
768 std::back_inserter(*isolated_nodes)));
769 }
770
HandleParamExistCorrespondFrontendParam(const KernelGraphPtr & graph)771 void HandleParamExistCorrespondFrontendParam(const KernelGraphPtr &graph) {
772 MS_EXCEPTION_IF_NULL(graph);
773 auto &context = CompileCacheContext::GetInstance();
774 const auto &front_graph = context.GetFrontendGraphByBackendGraph(graph);
775 if (!front_graph) {
776 return;
777 }
778 const auto ¶ms = graph->parameters();
779 const auto &front_params = front_graph->parameters();
780 for (const auto ¶m : params) {
781 auto front_param = graph->GetFrontAnfByBackendAnf(param);
782 if (!front_param) {
783 continue;
784 }
785 auto iter = std::find(front_params.begin(), front_params.end(), front_param);
786 if (iter != front_params.end()) {
787 context.InsertBackendParamGenFromFrontendParam(param);
788 }
789 }
790 }
791
NeedConvertValueNodeToParameter(const AnfNodePtr & node)792 bool NeedConvertValueNodeToParameter(const AnfNodePtr &node) {
793 MS_EXCEPTION_IF_NULL(node);
794 auto ms_context = MsContext::GetInstance();
795 MS_EXCEPTION_IF_NULL(ms_context);
796 if (ms_context->backend_policy() != "ge" || ms_context->IsKByKExecutorMode()) {
797 return false;
798 }
799 if (!node->isa<ValueNode>()) {
800 return false;
801 }
802 auto value_node = node->cast<ValueNodePtr>();
803 MS_EXCEPTION_IF_NULL(value_node);
804 auto value = value_node->value();
805 MS_EXCEPTION_IF_NULL(value);
806 if (value->isa<tensor::Tensor>()) {
807 auto tensor = value->cast<tensor::TensorPtr>();
808 MS_EXCEPTION_IF_NULL(tensor);
809 if (tensor->is_forward_output()) {
810 return true;
811 }
812 }
813 return false;
814 }
815
ConvertValueNodeToParameter(const KernelGraphPtr & graph,const AnfNodePtr & node,std::vector<ParameterPtr> * added_parameters)816 void ConvertValueNodeToParameter(const KernelGraphPtr &graph, const AnfNodePtr &node,
817 std::vector<ParameterPtr> *added_parameters) {
818 MS_EXCEPTION_IF_NULL(graph);
819 MS_EXCEPTION_IF_NULL(node);
820 MS_EXCEPTION_IF_NULL(added_parameters);
821 auto graph_inputs = graph->MutableInputs();
822 MS_EXCEPTION_IF_NULL(graph_inputs);
823 auto new_parameter = graph->NewParameter(node->abstract());
824 MS_EXCEPTION_IF_NULL(new_parameter);
825 new_parameter->IncreaseUsedGraphCount();
826 graph_inputs->push_back(new_parameter);
827
828 MS_EXCEPTION_IF_NULL(node->cast<ValueNodePtr>());
829 new_parameter->set_user_data(kForwardOutput, node->cast<ValueNodePtr>()->value());
830 graph->FrontBackendMapAdd(node, new_parameter);
831 (void)added_parameters->emplace_back(new_parameter);
832 MS_LOG(DEBUG) << "Replace ValueNode " << node->DebugString() << " with Parameter " << new_parameter->DebugString();
833 }
834 } // namespace
835
GetParamDefaultValue(const AnfNodePtr & node)836 ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
837 if (node == nullptr) {
838 return nullptr;
839 }
840 auto parameter = node->cast<ParameterPtr>();
841 if (parameter == nullptr || !parameter->has_default()) {
842 return nullptr;
843 }
844 return parameter->param_info();
845 }
846
847 #ifndef ENABLE_SECURITY
ExistSummaryNode(const KernelGraph * graph)848 bool ExistSummaryNode(const KernelGraph *graph) {
849 MS_EXCEPTION_IF_NULL(graph);
850 for (auto &n : TopoSort(graph->get_return())) {
851 if (AnfAlgo::IsSummaryNode(n)) {
852 return true;
853 }
854 }
855 return false;
856 }
857 #endif
858
859 GraphId KernelGraphMgr::graph_sum_ = 0;
860 GraphId KernelGraphMgr::pynative_graph_sum_ = kPynativeGraphIdStart;
861 HashMap<std::string, std::weak_ptr<AnfNode>> KernelGraphMgr::name_to_params_ =
862 HashMap<std::string, std::weak_ptr<AnfNode>>();
863
CreateNewValueNode(const AnfNodePtr & anf,KernelGraph * graph) const864 ValueNodePtr KernelGraphMgr::CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) const {
865 MS_EXCEPTION_IF_NULL(anf);
866 MS_EXCEPTION_IF_NULL(graph);
867 auto value_node = anf->cast<ValueNodePtr>();
868 MS_EXCEPTION_IF_NULL(value_node);
869 auto value = value_node->value();
870 MS_EXCEPTION_IF_NULL(value);
871 // Copy data from device if the tensor is an output of Op or Graph.
872 if (value->isa<tensor::Tensor>()) {
873 auto tensor = value->cast<TensorPtr>();
874 MS_EXCEPTION_IF_NULL(tensor);
875 if (!tensor->is_forward_output() && !tensor->is_parameter()) {
876 tensor->data_sync();
877 MS_LOG(INFO) << "Data sync for Tensor " << tensor->ToString();
878 }
879 }
880 auto new_value_node = value_node;
881 if (!graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
882 new_value_node = graph->NewValueNode(value_node);
883 graph->FrontBackendMapAdd(anf, new_value_node);
884 }
885 graph->AddValueNodeToGraph(new_value_node);
886 return new_value_node;
887 }
888
GetGraphIdByNode(const AnfNodePtr & front_anf) const889 GraphId KernelGraphMgr::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
890 for (const auto &graph_item : graphs_) {
891 auto graph = graph_item.second;
892 MS_EXCEPTION_IF_NULL(graph);
893 // if front_anf is a parameter,the backend parameter may have two
894 if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
895 return graph_item.first;
896 }
897 }
898 MS_EXCEPTION_IF_NULL(front_anf);
899 MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
900 return kInvalidGraphId;
901 }
902
GetGraph(mindspore::GraphId graph_id) const903 KernelGraphPtr KernelGraphMgr::GetGraph(mindspore::GraphId graph_id) const {
904 auto it = graphs_.find(graph_id);
905 if (it == graphs_.end()) {
906 MS_LOG(INFO) << "Can't find graph " << graph_id;
907 return nullptr;
908 }
909 return it->second;
910 }
911
ClearGraph()912 void KernelGraphMgr::ClearGraph() {
913 auto graph_iter = graphs_.begin();
914 while (graph_iter != graphs_.end()) {
915 graph_iter->second.reset();
916 graph_iter = graphs_.erase(graph_iter);
917 }
918 graph_sum_ = 0;
919 pynative_graph_sum_ = kPynativeGraphIdStart;
920 }
921
InitInternalOutputParameter(const AnfNodePtr & out_node,const AnfNodePtr & parameter) const922 void KernelGraphMgr::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) const {
923 MS_EXCEPTION_IF_NULL(out_node);
924 MS_EXCEPTION_IF_NULL(parameter);
925 MS_LOG(DEBUG) << "parameter:" << parameter->DebugString()
926 << " abstract:" << (parameter->abstract() != nullptr ? parameter->abstract()->ToString() : "null");
927 auto graph_id = GetGraphIdByNode(out_node);
928 if (graph_id == kInvalidGraphId) {
929 return;
930 }
931 auto node_graph = GetGraph(graph_id);
932 if (node_graph == nullptr) {
933 return;
934 }
935 MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
936 auto ref_node_with_index = node_graph->GetInternalOutputByFrontNode(out_node);
937 auto ref_node = ref_node_with_index.first;
938 if (ref_node == nullptr) {
939 MS_LOG(INFO) << "No corresponding internal output for output node";
940 return;
941 }
942 size_t output_idx = ref_node_with_index.second;
943 if (common::AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
944 output_idx = common::AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
945 }
946 auto real_kernel = common::AnfAlgo::VisitKernel(ref_node, output_idx);
947 auto ref_real_node = real_kernel.first;
948 MS_EXCEPTION_IF_NULL(ref_real_node);
949 auto ref_real_node_index = real_kernel.second;
950 if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
951 auto kernel_info = ref_real_node->kernel_info();
952 if (kernel_info == nullptr || !kernel_info->has_build_info()) {
953 MS_LOG(INFO) << "No kernel info";
954 return;
955 }
956 if (!common::AnfAlgo::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
957 MS_LOG(INFO) << "No kernel address";
958 return;
959 }
960 if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index, true)) {
961 return;
962 }
963
964 // Update the kernel build info.
965 auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
966 auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
967 if (type == TypeId::kTypeUnknown) {
968 return;
969 }
970 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
971 builder.SetOutputsDeviceType({type});
972 builder.SetOutputsFormat({format});
973 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), parameter.get());
974
975 abstract::AbstractBasePtr abstract;
976 auto shape = parameter->Shape();
977 MS_EXCEPTION_IF_NULL(shape);
978 if (shape->isa<abstract::NoShape>()) {
979 abstract = std::make_shared<abstract::AbstractScalar>(TypeIdToType(type));
980 } else if (shape->isa<abstract::DynamicSequenceShape>()) {
981 return;
982 } else {
983 abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape->cast<abstract::BaseShapePtr>());
984 }
985 if (!parameter->abstract()->isa<abstract::AbstractAny>()) {
986 parameter->set_abstract(abstract);
987 }
988 }
989 }
990
// Creates a new parameter in `graph` from the abstract of `node`, expanded through TransTupleToMakeTuple, and
// returns the expanded node. Every flattened output of it is appended to the graph inputs (marked valid) and,
// when MS_CTX_ENABLE_MINDRT is set, cached as an internal parameter of front node `node`.
// Afterwards each flattened parameter is initialized from the matching output of the previous graph.
// Raises an exception when the previous graph outputs need more parameters than were created.
AnfNodePtr KernelGraphMgr::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
  auto parameters = common::AnfAlgo::GetAllOutput(new_parameter);
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  if (!AnfUtils::IsRealKernel(node)) {
    pre_graph_out = common::AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
  }

  // These lookups do not depend on the parameter being processed, so resolve them once.
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  const bool enable_mindrt = context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  for (size_t i = 0; i < parameters.size(); ++i) {
    const auto &parameter = parameters[i];
    if (enable_mindrt) {
      // In control flow, if the input of the cnode is a call node, it will be processed as a make_tuple input,
      // which needs to be linked when processing the internal node.
      graph->CacheInternalParameterToFrontNode(parameter, {node, i});
    }
    valid_inputs->push_back(true);
    graph_inputs->push_back(parameter);
  }

  size_t param_index = 0;
  for (const auto &out_node : pre_graph_out) {
    const size_t output_size = AnfAlgo::GetOutputElementNum(out_node);
    for (size_t i = 0; i < output_size; i++) {
      if (param_index >= parameters.size()) {
        MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << " out of range. Node:" << node->DebugString()
                          << ",out_node:" << out_node->DebugString();
      }
      InitInternalOutputParameter(out_node, parameters[param_index++]);
    }
  }
  return new_parameter;
}
1031
CreateNewParameterFromParameter(const AnfNodePtr & anf,KernelGraph * graph)1032 ParameterPtr KernelGraphMgr::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
1033 MS_EXCEPTION_IF_NULL(anf);
1034 if (!anf->isa<Parameter>()) {
1035 MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
1036 }
1037 MS_EXCEPTION_IF_NULL(graph);
1038 auto param_value = GetParamDefaultValue(anf);
1039 auto valid_inputs = graph->MutableValidInputs();
1040 MS_EXCEPTION_IF_NULL(valid_inputs);
1041 auto graph_inputs = graph->MutableInputs();
1042 MS_EXCEPTION_IF_NULL(graph_inputs);
1043 ParameterPtr new_parameter = nullptr;
1044 auto func_graph = anf->func_graph();
1045 MS_EXCEPTION_IF_NULL(func_graph);
1046 bool is_pynative_bprop_kernel_graph = graph->has_flag(kFlagIsPyNativeBpropKernelGraph);
1047 if (func_graph->manager() != nullptr && func_graph->exist_multi_target() &&
1048 graph->device_target() == device::DeviceType::kCPU) {
1049 auto iter = default_param_map_.find(anf);
1050 if (iter != default_param_map_.end()) {
1051 new_parameter = iter->second;
1052 }
1053 if (new_parameter != nullptr) {
1054 graph_inputs->push_back(new_parameter);
1055 MS_LOG(DEBUG) << "create new parameter for parameter:" << anf->DebugString() << " for graph:" << graph->ToString()
1056 << " backend node:" << new_parameter->DebugString();
1057 return new_parameter;
1058 }
1059 TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1060 new_parameter = anf->cast<ParameterPtr>();
1061 if (!is_pynative_bprop_kernel_graph) {
1062 new_parameter = graph->NewParameter(new_parameter);
1063 }
1064 graph_inputs->push_back(new_parameter);
1065 valid_inputs->push_back(true);
1066 default_param_map_[anf] = new_parameter;
1067 return new_parameter;
1068 }
1069 // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
1070 if (!is_pynative_bprop_kernel_graph) {
1071 auto context = MsContext::GetInstance();
1072 if (!context->IsKByKExecutorMode() && param_value != nullptr) {
1073 new_parameter = param_value->parameter();
1074 }
1075 if (new_parameter == nullptr) {
1076 TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1077 new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1078
1079 auto input_node_iter = partial_parameters_map_.find(anf);
1080 if (input_node_iter != partial_parameters_map_.end()) {
1081 InitInternalOutputParameter(input_node_iter->second, new_parameter);
1082 }
1083
1084 if (param_value != nullptr) {
1085 param_value->set_parameter(new_parameter);
1086 }
1087 }
1088 new_parameter->IncreaseUsedGraphCount();
1089 } else {
1090 new_parameter = anf->cast<ParameterPtr>();
1091 }
1092 (void)graph_inputs->emplace_back(new_parameter);
1093 (void)valid_inputs->emplace_back(true);
1094 return new_parameter;
1095 }
1096
// Creates the parameter(s) in `graph` that stand for front cnode `anf` by delegating to
// CreateParameterFromTuple, and returns its result.
AnfNodePtr KernelGraphMgr::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  AnfNodePtr created = CreateParameterFromTuple(anf, graph);
  return created;
}
1103
GetCNodeInfo(const CNodePtr & cnode,std::vector<AnfNodePtr> * cnode_inputs) const1104 void KernelGraphMgr::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const {
1105 MS_EXCEPTION_IF_NULL(cnode);
1106 MS_EXCEPTION_IF_NULL(cnode_inputs);
1107 auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
1108 if (prim != nullptr) {
1109 // push attr to inputs[0] of new cnode
1110 cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
1111 } else {
1112 auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
1113 MS_EXCEPTION_IF_NULL(fg);
1114 auto new_fg = BasicClone(fg);
1115 cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
1116 }
1117 }
1118
GetChildGraph(KernelGraph * graph,const AnfNodePtr & child_func_graph)1119 AnfNodePtr KernelGraphMgr::GetChildGraph(KernelGraph *graph, const AnfNodePtr &child_func_graph) {
1120 MS_EXCEPTION_IF_NULL(child_func_graph);
1121 std::vector<KernelGraphPtr> all_graphs;
1122 FuncGraphPtr child_graph = common::AnfAlgo::GetValueNodeFuncGraph(child_func_graph);
1123 MS_EXCEPTION_IF_NULL(child_graph);
1124 if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
1125 (void)ConstructKernelGraph(child_graph, &all_graphs, graph->device_target());
1126 }
1127 auto new_value_node = graph->GetBackendAnfByFrontAnf(child_func_graph);
1128 if (new_value_node != nullptr) {
1129 return new_value_node;
1130 }
1131 new_value_node = CreateValueNodeKernelGraph(child_func_graph, graph);
1132 MS_EXCEPTION_IF_NULL(new_value_node);
1133 return new_value_node;
1134 }
1135
1136 namespace {
AddValueNode(const AnfNodePtr & backend_node,KernelGraph * graph)1137 void AddValueNode(const AnfNodePtr &backend_node, KernelGraph *graph) {
1138 if (backend_node->isa<ValueNode>() && !IsValueNode<FuncGraph>(backend_node)) {
1139 graph->AddValueNodeToGraph(backend_node->cast<ValueNodePtr>());
1140 }
1141 }
1142 } // namespace
1143
GetNewCNodeInputs(const CNodePtr & cnode,KernelGraph * graph,std::vector<AnfNodePtr> * cnode_inputs,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * other_graph_cnode)1144 void KernelGraphMgr::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
1145 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
1146 MS_EXCEPTION_IF_NULL(cnode);
1147 MS_EXCEPTION_IF_NULL(graph);
1148 MS_EXCEPTION_IF_NULL(other_graph_cnode);
1149 MS_EXCEPTION_IF_NULL(cnode_inputs);
1150 auto origin_inputs = cnode->inputs();
1151 const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
1152 auto context = MsContext::GetInstance();
1153 MS_EXCEPTION_IF_NULL(context);
1154 const bool enable_ge = context->backend_policy() == "ge";
1155 AnfNodePtr child_func_graph = nullptr;
1156 std::vector<AnfNodePtr> params;
1157 // if has multiple depends,only select first depend as parameter
1158 for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
1159 auto anf = origin_inputs[input_idx];
1160 MS_EXCEPTION_IF_NULL(anf);
1161 // anf has been created before
1162 if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
1163 const auto &backend_node = graph->GetBackendAnfByFrontAnf(anf);
1164 (void)params.emplace_back(backend_node);
1165 AddValueNode(backend_node, graph);
1166 continue;
1167 } else if ((is_depend && input_idx > kRealInputIndexInDepend && !enable_ge)) {
1168 (void)params.emplace_back(graph->NewValueNode(std::make_shared<Tensor>(SizeToInt(input_idx))));
1169 continue;
1170 } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
1171 (void)params.emplace_back((*other_graph_cnode)[anf]);
1172 continue;
1173 } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
1174 // if input is a value node,
1175 auto new_value_node = CreateNewValueNode(anf, graph);
1176 if (new_value_node != nullptr) {
1177 (void)params.emplace_back(new_value_node);
1178 }
1179 continue;
1180 } else if (anf->isa<Parameter>()) {
1181 auto new_parameter = CreateNewParameterFromParameter(anf, graph);
1182 MS_EXCEPTION_IF_NULL(new_parameter);
1183 MS_LOG(DEBUG) << "Create new parameter:" << new_parameter->DebugString()
1184 << " by front parameter:" << anf->DebugString();
1185 (void)params.emplace_back(new_parameter);
1186 graph->FrontBackendMapAdd(anf, new_parameter);
1187 continue;
1188 } else if (IsValueNode<FuncGraph>(anf) && cnode->HasPrimalAttr(kAttrNotCut)) {
1189 MS_EXCEPTION_IF_CHECK_FAIL(input_idx == 1, "Graph input index is not 1, anf: " + anf->DebugString() +
1190 ", index: " + std::to_string(input_idx));
1191 child_func_graph = anf;
1192 continue;
1193 } else {
1194 // the input node is a cnode from other graph
1195 auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
1196 if (parameter_from_cnode == nullptr) {
1197 parameter_from_cnode = NewValueNode(MakeValue(SizeToLong(input_idx)));
1198 }
1199 MS_EXCEPTION_IF_NULL(parameter_from_cnode);
1200 MS_LOG(DEBUG) << "graph:" << graph->ToString() << " front node:" << anf->DebugString()
1201 << " abstract:" << (anf->abstract() != nullptr ? anf->abstract()->ToString() : "null")
1202 << " parameter:" << parameter_from_cnode->DebugString() << " abstract:"
1203 << (parameter_from_cnode->abstract() != nullptr ? parameter_from_cnode->abstract()->ToString()
1204 : "null");
1205 if (parameter_from_cnode->isa<Parameter>() && IsPrimitiveCNode(anf, prim::kPrimLoad)) {
1206 auto para = parameter_from_cnode->cast<ParameterPtr>();
1207 auto load_cnode = anf->cast<CNodePtr>();
1208 para->set_name(load_cnode->fullname_with_scope());
1209 }
1210 (void)params.emplace_back(parameter_from_cnode);
1211 (*other_graph_cnode)[anf] = parameter_from_cnode;
1212 }
1213 }
1214
1215 if (child_func_graph != nullptr) {
1216 (void)cnode_inputs->emplace_back(GetChildGraph(graph, child_func_graph));
1217 }
1218 (void)std::copy(params.begin(), params.end(), std::back_inserter(*cnode_inputs));
1219 }
1220
// Creates the backend cnode in `graph` for front `cnode`.
// A Switch call carrying kAttrNotCut (control flow sink) or a cnode carrying kAttrNeedInline (backend inline)
// is built through the two-argument overload and flattened; for the inline case the child kernel graph is
// created first when input 0 is a FuncGraph value node, and the new node is tagged with kAttrNeedInline.
// Every other cnode is rebuilt from its primitive/func graph and freshly converted inputs.
CNodePtr KernelGraphMgr::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
                                        mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  const auto callee = cnode->input(kAnfPrimitiveIndex);
  // control flow sink to GE
  const bool sink_control_flow = IsPrimitiveCNode(callee, prim::kPrimSwitch) && cnode->HasPrimalAttr(kAttrNotCut);
  // backend inline
  const bool inline_in_backend = cnode->HasPrimalAttr(kAttrNeedInline);
  if (inline_in_backend) {
    MS_EXCEPTION_IF_NULL(callee);
    if (IsValueNode<FuncGraph>(callee)) {
      // Need to create a new kernel graph
      (void)GetChildGraph(graph, callee);
    }
  }
  if (!sink_control_flow && !inline_in_backend) {
    // get primitive of old node
    std::vector<AnfNodePtr> new_inputs;
    GetCNodeInfo(cnode, &new_inputs);
    GetNewCNodeInputs(cnode, graph, &new_inputs, other_graph_cnode);
    TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
    return graph->NewCNodeWithInfos(new_inputs, cnode);
  }
  auto call_node = CreateNewCNode(cnode, graph);
  MS_EXCEPTION_IF_NULL(call_node);
  FlattenTuple(call_node);
  if (inline_in_backend) {
    call_node->AddPrimalAttr(kAttrNeedInline, MakeValue(true));
  }
  MS_LOG(DEBUG) << "Create new call node:" << call_node->DebugString() << " by front node:" << cnode->DebugString();
  return call_node;
}
1258
// Creates the backend cnode in `graph` for front `cnode`, choosing input 0 by the kind of the front input 0:
// a FuncGraph value node (graph or call), a CNode (call of partial/switch/switch_layer) or a primitive.
// Returns nullptr when the switch/partial inputs cannot be created.
// When the result is a Call whose first data input is a Switch or SwitchLayer node, that input is returned
// instead of the Call; the SwitchLayer node additionally receives the abstract of the front `cnode`.
CNodePtr KernelGraphMgr::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> new_inputs;
  const auto callee = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(callee);
  if (IsValueNode<FuncGraph>(callee)) {
    // cnode is a graph or a call
    new_inputs = CreateValueNode(cnode, graph);
  } else if (callee->isa<CNode>()) {
    // cnode is a call (partial/switch/switch_layer)
    // 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph
    // 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created
    new_inputs = CreateSwitchOrPartialNode(cnode, graph);
    if (new_inputs.empty()) {
      MS_LOG(ERROR) << "Create switch or partial failed, cnode:" << cnode->DebugString();
      return nullptr;
    }
  } else {
    // get primitive of old node
    const auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(primitive);
    // push attr to inputs[0] of new cnode
    new_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*primitive)))};
  }
  // handle inputs of cnode except primitive
  CreateCNodeInputs(cnode, graph, &new_inputs);
  TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto backend_cnode = graph->NewCNodeWithInfos(new_inputs, cnode);
  MS_EXCEPTION_IF_NULL(backend_cnode);
  if (backend_cnode->size() <= 1) {
    return backend_cnode;
  }
  // if the cnode is call switch, remove call
  const auto first_input = backend_cnode->input(kFirstDataInputIndex);
  MS_EXCEPTION_IF_NULL(first_input);
  if (!common::AnfAlgo::CheckPrimitiveType(backend_cnode, prim::kPrimCall)) {
    return backend_cnode;
  }
  if (common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
    return first_input->cast<CNodePtr>();
  }
  if (common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
    const auto front_abstract = cnode->abstract();
    auto switch_layer = first_input->cast<CNodePtr>();
    switch_layer->set_abstract(front_abstract);
    return switch_layer;
  }
  return backend_cnode;
}
1306
// Returns the partial node to be used as a branch input of a switch for front input `node_input`:
//  - a Partial front node: its existing backend node is returned as is;
//  - a FuncGraph value node: a new Partial over its backend node;
//  - anything else: a new kernel graph that simply returns a parameter created from `cnode` is built, and a
//    new Partial over that graph and the backend node of `node_input` is returned.
// `cnode` is only required (non-null) in the last case.
CNodePtr KernelGraphMgr::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node_input);
  MS_EXCEPTION_IF_NULL(graph);
  if (common::AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
    auto backend_node = graph->GetBackendAnfByFrontAnf(node_input);
    MS_EXCEPTION_IF_NULL(backend_node);
    return backend_node->cast<CNodePtr>();
  }
  // switch input generalizes partial
  std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
    (void)(partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)));
  } else {
    // Validate cnode before a kernel graph is created for it, so that a null cnode leaves no new graph behind.
    MS_EXCEPTION_IF_NULL(cnode);
    KernelGraphPtr kernel_graph = NewKernelGraph();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(cnode->abstract());
    auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
    auto return_node = kernel_graph->NewCNode({primitive, parameter});
    MS_EXCEPTION_IF_NULL(return_node);
    return_node->set_abstract(cnode->abstract());
    kernel_graph->set_return(return_node);
    (void)(partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)));
    (void)(partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)));
  }
  auto partial_node = graph->NewCNode(partial_inputs);
  return partial_node;
}
1336
CacheKernelGraph(const KernelGraphPtr & kg)1337 void KernelGraphMgr::CacheKernelGraph(const KernelGraphPtr &kg) {
1338 MS_EXCEPTION_IF_NULL(kg);
1339 auto &context = CompileCacheContext::GetInstance();
1340 auto fg = context.FrontGraph();
1341 if (!fg) {
1342 MS_LOG(EXCEPTION) << "The frontend graph to be cached is null";
1343 }
1344 if (!kg) {
1345 MS_LOG(EXCEPTION) << "The backend graph to be cached is null";
1346 }
1347 MS_LOG(INFO) << "Begin to cache kernel graph " << kg->ToString();
1348 std::set<KernelGraphPtr> visit;
1349 std::set<KernelGraphPtr> child_graphs;
1350 GetAllChildGraph(kg, &visit, &child_graphs);
1351 #ifdef ENABLE_DUMP_IR
1352 auto ms_context = MsContext::GetInstance();
1353 MS_EXCEPTION_IF_NULL(ms_context);
1354 if (ms_context->CanDump(kIntroductory)) {
1355 DumpIR("compile_cache_" + kg->ToString() + ".ir", kg);
1356 for (auto &graph : child_graphs) {
1357 DumpIR("compile_cache_" + graph->ToString() + ".ir", graph);
1358 }
1359 }
1360 #endif
1361 std::vector<AnfNodePtr> temp_nodes;
1362 std::map<KernelGraphPtr, std::vector<AnfNodePtr>> isolated_nodes_map;
1363 HandleParamExistCorrespondFrontendParam(kg);
1364 GetIsolatedNodes(kg, &temp_nodes);
1365 isolated_nodes_map[kg] = temp_nodes;
1366 auto cache_path = context.GetBackendGraphCachePath(fg);
1367 const std::string &mindir_path = cache_path + kMindIrSuffix;
1368 for (const auto &graph : child_graphs) {
1369 temp_nodes.clear();
1370 HandleParamExistCorrespondFrontendParam(graph);
1371 GetIsolatedNodes(graph, &temp_nodes);
1372 isolated_nodes_map[graph] = temp_nodes;
1373 }
1374 std::vector<AnfNodePtr> isolated_nodes;
1375 for (const auto &iter : isolated_nodes_map) {
1376 const auto &nodes = iter.second;
1377 (void)(isolated_nodes.insert(isolated_nodes.end(), nodes.begin(), nodes.end()));
1378 }
1379 std::vector<FuncGraphPtr> child_graphs_for_dump(child_graphs.begin(), child_graphs.end());
1380 if (!DumpBinaryProto(kg, child_graphs_for_dump, isolated_nodes, mindir_path)) {
1381 MS_LOG(ERROR) << "Failed to cache kernel graph to mindir: " << fg->ToString();
1382 return;
1383 }
1384 (void)(std::for_each(front_backend_graph_map_.begin(), front_backend_graph_map_.end(),
1385 [&context](const auto &fb) { context.AddBackendGraphToFrontendGraph(fb.second, fb.first); }));
1386 const std::string &json_path = cache_path + kJsonSuffix;
1387 if (!DumpKernelGraphJson(kg, child_graphs, isolated_nodes_map, json_path)) {
1388 MS_LOG(ERROR) << "Failed to cache kernel graph to json.";
1389 return;
1390 }
1391 context.Clear();
1392 MS_LOG(INFO) << "Cache kernel graph " << kg->ToString() << " success.";
1393 }
1394
CreateCallSwitchInputs(const CNodePtr & cnode,KernelGraph * graph) const1395 std::vector<AnfNodePtr> KernelGraphMgr::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) const {
1396 MS_EXCEPTION_IF_NULL(cnode);
1397 MS_EXCEPTION_IF_NULL(graph);
1398 std::vector<AnfNodePtr> cnode_inputs = {
1399 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1400 auto attr_input = cnode->input(kAnfPrimitiveIndex);
1401 MS_EXCEPTION_IF_NULL(attr_input);
1402 auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
1403 MS_EXCEPTION_IF_NULL(cnode_input);
1404 auto switch_cnode = cnode_input->cast<CNodePtr>();
1405 MS_EXCEPTION_IF_NULL(switch_cnode);
1406 if (cnode->size() <= 1) {
1407 cnode_inputs = switch_cnode->inputs();
1408 return cnode_inputs;
1409 }
1410 std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
1411 switch_cnode->input(kFirstDataInputIndex)};
1412 for (size_t index = kSwitchTrueBranchIndex; index < switch_cnode->size(); index++) {
1413 auto node = switch_cnode->input(index);
1414 MS_EXCEPTION_IF_NULL(node);
1415 // there is real input in call, should put it to true and false branch in switch
1416 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
1417 auto partial_node = node->cast<CNodePtr>();
1418 MS_EXCEPTION_IF_NULL(partial_node);
1419 std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
1420 // Put all call args at the end of partial inputs.
1421 for (size_t i = kFirstDataInputIndex; i < cnode->size(); ++i) {
1422 (void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(i)));
1423 }
1424 auto new_partial = graph->NewCNode(partial_inputs);
1425 (void)switch_inputs.emplace_back(new_partial);
1426 }
1427 }
1428 if (switch_inputs.size() < kSwitchInputSize) {
1429 MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
1430 }
1431 auto switch_node = graph->NewCNode(switch_inputs);
1432 (void)cnode_inputs.emplace_back(switch_node);
1433 return cnode_inputs;
1434 }
1435
ProcessNodeRetFunc(const CNodePtr & cnode,KernelGraph * graph,const std::vector<AnfNodePtr> & real_inputs)1436 void KernelGraphMgr::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
1437 const std::vector<AnfNodePtr> &real_inputs) {
1438 MS_EXCEPTION_IF_NULL(cnode);
1439 // func1 =switch(branch1, branch2)
1440 // func2 = func1(param1)
1441 // out = func2(param2)
1442 // process the last cnode(func2), not func1 which abstract is AbstractFunction
1443 if (cnode->abstract()->isa<abstract::AbstractFunction>()) {
1444 return;
1445 }
1446 MS_EXCEPTION_IF_NULL(graph);
1447 auto ret = graph->get_return();
1448 MS_EXCEPTION_IF_NULL(ret);
1449 auto return_input = ret->input(kFirstDataInputIndex);
1450 // return node is a function
1451 std::vector<AnfNodePtr> call_inputs = {
1452 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1453 if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
1454 auto return_input_cnode = return_input->cast<CNodePtr>();
1455 MS_EXCEPTION_IF_NULL(return_input_cnode);
1456 auto partial_inputs = return_input_cnode->inputs();
1457 (void)call_inputs.insert(call_inputs.cend(), partial_inputs.cbegin() + kFirstDataInputIndex, partial_inputs.cend());
1458 } else if (IsValueNode<KernelGraph>(return_input)) { // return node is kernel graph
1459 (void)(call_inputs.emplace_back(return_input));
1460 } else { // return node is value node
1461 KernelGraphPtr kernel_graph = NewKernelGraph();
1462 MS_EXCEPTION_IF_NULL(kernel_graph);
1463 auto valid_inputs = kernel_graph->MutableValidInputs();
1464 MS_EXCEPTION_IF_NULL(valid_inputs);
1465 auto graph_inputs = kernel_graph->MutableInputs();
1466 MS_EXCEPTION_IF_NULL(graph_inputs);
1467 std::vector<AnfNodePtr> cnode_inputs = {return_input};
1468 for (auto &real_input : real_inputs) {
1469 auto new_parameter = kernel_graph->NewParameter(real_input->abstract());
1470 valid_inputs->push_back(true);
1471 graph_inputs->push_back(new_parameter);
1472 cnode_inputs.push_back(new_parameter);
1473 }
1474 auto new_cnode = kernel_graph->NewCNode(cnode_inputs);
1475 new_cnode->set_abstract(cnode->abstract());
1476 std::vector<AnfNodePtr> return_inputs = {
1477 kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode};
1478 auto return_node = kernel_graph->NewCNode(return_inputs);
1479 return_node->set_abstract(cnode->abstract());
1480 kernel_graph->set_return(return_node);
1481 call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph));
1482 }
1483
1484 // new call node inputs
1485 for (auto &input_node : real_inputs) {
1486 auto parameter_for_input = CreateNewParameterFromCNode(input_node, graph);
1487 (void)(call_inputs.emplace_back(parameter_for_input));
1488 }
1489
1490 auto call_node = graph->NewCNode(call_inputs);
1491 MS_EXCEPTION_IF_NULL(call_node);
1492 call_node->set_abstract(cnode->abstract());
1493 // update return input
1494 ret->set_input(kFirstDataInputIndex, call_node);
1495 }
1496
CreateCallSwitchLayerInputs(const CNodePtr & cnode,KernelGraph * graph)1497 std::vector<AnfNodePtr> KernelGraphMgr::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
1498 MS_EXCEPTION_IF_NULL(cnode);
1499 MS_EXCEPTION_IF_NULL(graph);
1500 std::vector<AnfNodePtr> cnode_inputs = {
1501 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1502 auto attr_input = cnode->input(kAnfPrimitiveIndex);
1503 MS_EXCEPTION_IF_NULL(attr_input);
1504 auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
1505 MS_EXCEPTION_IF_NULL(cnode_input);
1506 auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
1507 MS_EXCEPTION_IF_NULL(switch_layer_cnode);
1508 std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
1509 switch_layer_cnode->input(kFirstDataInputIndex)};
1510 auto make_tuple_node = switch_layer_cnode->input(kSwitchLayerBranchesIndex);
1511 MS_EXCEPTION_IF_NULL(make_tuple_node);
1512 auto node = make_tuple_node->cast<CNodePtr>();
1513 MS_EXCEPTION_IF_NULL(node);
1514 auto make_tuple_inputs = node->inputs();
1515 // there are real inputs in call, should put it to make_tuple in switch_layer
1516 std::vector<AnfNodePtr> real_inputs;
1517 for (size_t idx = kFirstDataInputIndex; idx < cnode->size(); ++idx) {
1518 (void)(real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx))));
1519 }
1520 std::vector<AnfNodePtr> new_make_tuple_inputs = {
1521 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
1522 for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
1523 auto partial_idx = make_tuple_inputs[idx];
1524 MS_EXCEPTION_IF_NULL(cnode->abstract());
1525 std::vector<AnfNodePtr> new_partial_inputs;
1526 KernelGraphPtr partial_kernel_graph;
1527 // switch_layer node input is partial cnode
1528 if (common::AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
1529 auto partial_node = partial_idx->cast<CNodePtr>();
1530 MS_EXCEPTION_IF_NULL(partial_node);
1531 auto partial_input = partial_node->input(kFirstDataInputIndex);
1532 partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
1533 new_partial_inputs = partial_node->inputs();
1534 } else if (IsValueNode<KernelGraph>(partial_idx)) { // switch_layer node input is kernel graph value node
1535 (void)(new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))));
1536 (void)(new_partial_inputs.emplace_back(partial_idx));
1537 partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx);
1538 }
1539 // when branch in swich_layer return function
1540 MS_EXCEPTION_IF_NULL(partial_kernel_graph);
1541 auto ret = partial_kernel_graph->get_return();
1542 MS_EXCEPTION_IF_NULL(ret);
1543 auto return_input = ret->input(kFirstDataInputIndex);
1544 if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
1545 ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
1546 }
1547 // partial node add input args
1548 (void)new_partial_inputs.insert(new_partial_inputs.cend(), real_inputs.cbegin(), real_inputs.cend());
1549 // create new partial node
1550 auto new_partial = graph->NewCNode(new_partial_inputs);
1551 (void)(new_make_tuple_inputs.emplace_back(new_partial));
1552 }
1553 auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
1554 auto abstract = make_tuple_node->abstract();
1555 if (abstract == nullptr) {
1556 abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
1557 }
1558 new_make_tuple->set_abstract(abstract);
1559 (void)(switch_layer_inputs.emplace_back(new_make_tuple));
1560 auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
1561 (void)(cnode_inputs.emplace_back(new_switch_layer));
1562 return cnode_inputs;
1563 }
1564
CreateSwitchOrPartialNode(const CNodePtr & cnode,KernelGraph * graph)1565 std::vector<AnfNodePtr> KernelGraphMgr::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
1566 MS_EXCEPTION_IF_NULL(cnode);
1567 MS_EXCEPTION_IF_NULL(graph);
1568 // create primitive of cnode:call(partial or switch or switch_layer)
1569 std::vector<AnfNodePtr> cnode_inputs = {
1570 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1571 auto attr_input = cnode->input(kAnfPrimitiveIndex);
1572 MS_EXCEPTION_IF_NULL(attr_input);
1573 auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
1574 if (cnode_input == nullptr) {
1575 MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created.";
1576 return {};
1577 }
1578 // if the node is partial, insert the inputs of partial to the call
1579 if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
1580 auto partial_node = attr_input->cast<CNodePtr>();
1581 MS_EXCEPTION_IF_NULL(partial_node);
1582 auto partial_inputs = partial_node->inputs();
1583 (void)std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
1584 std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
1585 MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
1586 return graph->GetBackendAnfByFrontAnf(node);
1587 });
1588 return cnode_inputs;
1589 } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
1590 return CreateCallSwitchInputs(cnode, graph);
1591 } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
1592 return CreateCallSwitchLayerInputs(cnode, graph);
1593 } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
1594 // only support tuple get item from a call subgraph output
1595 auto tuple_get_node = cnode_input->cast<CNodePtr>();
1596 MS_EXCEPTION_IF_NULL(tuple_get_node);
1597 auto get_from_node = tuple_get_node->input(kFirstIndex);
1598 MS_EXCEPTION_IF_NULL(get_from_node);
1599 if (common::AnfAlgo::CheckPrimitiveType(get_from_node, prim::kPrimCall)) {
1600 auto call_node = get_from_node->cast<CNodePtr>();
1601 MS_EXCEPTION_IF_NULL(call_node);
1602 auto call_graph = call_node->input(kFirstIndex);
1603 auto sub_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(call_graph);
1604 MS_EXCEPTION_IF_NULL(sub_kernel_graph);
1605 if (kernel_graph_partial_map_.find(sub_kernel_graph.get()) == kernel_graph_partial_map_.end()) {
1606 MS_LOG(EXCEPTION) << "Kernel Graph: " << sub_kernel_graph->ToString()
1607 << " has not a return value is a Partial Func.";
1608 }
1609 auto tuple_get_idx = common::AnfAlgo::GetTupleGetItemOutIndex(tuple_get_node);
1610 auto info = kernel_graph_partial_map_[sub_kernel_graph.get()];
1611 call_node->set_abstract(info.abstract);
1612 (void)cnode_inputs.emplace_back(info.sub_graph);
1613 auto context = MsContext::GetInstance();
1614 MS_EXCEPTION_IF_NULL(context);
1615 if (context->CellReuseLevel() == CellReuseLevel::kLazyInline) {
1616 // call_graph and info.sub_graph need inline when cell reuse.
1617 sub_kernel_graph->set_need_inline(true);
1618 auto partial_sub_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(info.sub_graph);
1619 MS_EXCEPTION_IF_NULL(partial_sub_graph);
1620 partial_sub_graph->set_need_inline(true);
1621 MS_LOG(INFO) << "Inline graph " << sub_kernel_graph->graph_id() << " and graph "
1622 << partial_sub_graph->graph_id();
1623 }
1624 MS_LOG(INFO) << "Use cell reuse: " << sub_kernel_graph->graph_id();
1625 if (info.param_begin != tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)) {
1626 MS_LOG(EXCEPTION) << "Call param is not a graph, the TupleGetItem index: " << tuple_get_idx
1627 << ", the partial graph index: " << info.param_begin
1628 << ", need idx: " << tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)
1629 << ", call graph: " << call_graph->fullname_with_scope();
1630 }
1631 for (size_t i = info.param_begin; i < info.param_end; i++) {
1632 auto idx = NewValueNode(SizeToLong(i));
1633 MS_EXCEPTION_IF_NULL(idx);
1634 auto imm = std::make_shared<Int64Imm>(i);
1635 idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
1636 auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), call_node, idx});
1637 std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(call_node, i)};
1638 auto shapes = {common::AnfAlgo::GetOutputInferShape(call_node, i)};
1639 common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
1640 (void)cnode_inputs.emplace_back(getitem);
1641 }
1642 return cnode_inputs;
1643 }
1644 }
1645 MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
1646 << "must be partial or switch or switch_layer.";
1647 return {};
1648 }
1649
CreateValueNode(const CNodePtr & cnode,KernelGraph * graph)1650 std::vector<AnfNodePtr> KernelGraphMgr::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
1651 MS_EXCEPTION_IF_NULL(cnode);
1652 MS_EXCEPTION_IF_NULL(graph);
1653 std::vector<AnfNodePtr> cnode_inputs;
1654 auto attr_input = cnode->input(kAnfPrimitiveIndex);
1655 MS_EXCEPTION_IF_NULL(attr_input);
1656 if (common::AnfAlgo::IsGraphKernel(cnode)) {
1657 auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
1658 MS_EXCEPTION_IF_NULL(fg);
1659 auto new_fg = BasicClone(fg);
1660 cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
1661 } else {
1662 // create primitive of cnode:call
1663 cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1664 // create a ValueNode<KernelGraph> as input of cnode:call
1665 if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
1666 (void)(cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)));
1667 } else {
1668 auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
1669 if (new_value_node != nullptr) {
1670 (void)(cnode_inputs.emplace_back(new_value_node));
1671 }
1672 }
1673 }
1674 return cnode_inputs;
1675 }
1676
CreateCNodeInputs(const CNodePtr & cnode,KernelGraph * graph,std::vector<AnfNodePtr> * cnode_inputs)1677 void KernelGraphMgr::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
1678 std::vector<AnfNodePtr> *cnode_inputs) {
1679 MS_EXCEPTION_IF_NULL(cnode);
1680 MS_EXCEPTION_IF_NULL(graph);
1681 MS_EXCEPTION_IF_NULL(cnode_inputs);
1682 if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
1683 (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
1684 for (size_t index = kSwitchTrueBranchIndex; index < cnode->size(); index++) {
1685 auto node_input = cnode->input(index);
1686 auto switch_input = CreateSwitchInput(cnode, node_input, graph);
1687 (void)cnode_inputs->emplace_back(switch_input);
1688 }
1689 } else {
1690 for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->size(); input_idx++) {
1691 auto anf = cnode->input(input_idx);
1692 MS_EXCEPTION_IF_NULL(anf);
1693 // anf has been created before
1694 if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
1695 (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
1696 continue;
1697 } else if (anf->isa<Parameter>()) {
1698 auto new_parameter = CreateNewParameterFromParameter(anf, graph);
1699 MS_EXCEPTION_IF_NULL(new_parameter);
1700 (void)cnode_inputs->emplace_back(new_parameter);
1701 graph->FrontBackendMapAdd(anf, new_parameter);
1702 continue;
1703 } else if (anf->isa<ValueNode>()) {
1704 auto new_value_node = CreateNewValueNode(anf, graph);
1705 MS_EXCEPTION_IF_NULL(new_value_node);
1706 (void)cnode_inputs->emplace_back(new_value_node);
1707 continue;
1708 }
1709 MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
1710 }
1711 }
1712 }
1713
CreateValueNodeKernelGraph(const AnfNodePtr & anf,KernelGraph * graph)1714 ValueNodePtr KernelGraphMgr::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
1715 MS_EXCEPTION_IF_NULL(anf);
1716 MS_EXCEPTION_IF_NULL(graph);
1717 auto value_node = anf->cast<ValueNodePtr>();
1718 MS_EXCEPTION_IF_NULL(value_node);
1719 auto sub_func_graph = common::AnfAlgo::GetValueNodeFuncGraph(anf);
1720 MS_EXCEPTION_IF_NULL(sub_func_graph);
1721 if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) {
1722 MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
1723 }
1724 auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph.get()];
1725
1726 ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
1727 MS_EXCEPTION_IF_NULL(new_value_node);
1728 new_value_node->set_abstract(value_node->abstract());
1729 // create new kernel_info of new value_node
1730 auto kernel_info = std::make_shared<device::KernelInfo>();
1731 MS_EXCEPTION_IF_NULL(kernel_info);
1732 new_value_node->set_kernel_info(kernel_info);
1733 // create kernel_build_info for new value node
1734 auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
1735 MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
1736 AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
1737 AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
1738
1739 graph->FrontBackendMapAdd(anf, new_value_node);
1740
1741 return new_value_node;
1742 }
1743
CreateNewParameter(const AnfNodePtr & anf,KernelGraph * graph) const1744 ParameterPtr KernelGraphMgr::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) const {
1745 MS_EXCEPTION_IF_NULL(anf);
1746 MS_EXCEPTION_IF_NULL(graph);
1747 if (!anf->isa<Parameter>()) {
1748 MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
1749 }
1750
1751 auto param_value = GetParamDefaultValue(anf);
1752 ParameterPtr new_parameter = nullptr;
1753 // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
1754 if (param_value != nullptr) {
1755 new_parameter = param_value->parameter();
1756 if (new_parameter == nullptr) {
1757 TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1758 new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1759 param_value->set_parameter(new_parameter);
1760 }
1761 } else {
1762 TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1763 new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1764 }
1765
1766 new_parameter->IncreaseUsedGraphCount();
1767
1768 return new_parameter;
1769 }
1770
FlattenTuple(const CNodePtr & node)1771 void KernelGraphMgr::FlattenTuple(const CNodePtr &node) {
1772 MS_EXCEPTION_IF_NULL(node);
1773 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
1774 auto call_graph = node->input(kFirstIndex);
1775 auto sub_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(call_graph);
1776 MS_EXCEPTION_IF_NULL(sub_kernel_graph);
1777 auto iter = kernel_graph_partial_map_.find(sub_kernel_graph.get());
1778 if (iter != kernel_graph_partial_map_.end() && iter->second.multi_tuple != 0) {
1779 (void)need_flatten_.insert(node);
1780 }
1781 } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
1782 auto input = node->input(kFirstIndex);
1783 auto get_idx = common::AnfAlgo::GetTupleGetItemOutIndex(node);
1784 if (need_flatten_.find(input) != need_flatten_.end() && get_idx == 0) {
1785 need_flatten_tuple_map_[node] = input;
1786 }
1787 }
1788 for (size_t i = 0; i < common::AnfAlgo::GetInputNum(node); i++) {
1789 auto input = common::AnfAlgo::GetInputNode(node, i);
1790 auto iter = need_flatten_tuple_map_.find(input);
1791 if (iter != need_flatten_tuple_map_.end()) {
1792 node->set_input(i + 1, iter->second);
1793 }
1794 }
1795 }
1796
CreateCNodeOfKernelGraph(const AnfNodePtr & node,KernelGraph * graph)1797 bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
1798 MS_EXCEPTION_IF_NULL(node);
1799 MS_EXCEPTION_IF_NULL(graph);
1800 auto cnode = node->cast<CNodePtr>();
1801 MS_EXCEPTION_IF_NULL(cnode);
1802 // create a new cnode object
1803 auto new_cnode = CreateNewCNode(cnode, graph);
1804 if (new_cnode == nullptr) {
1805 return false;
1806 }
1807 new_cnode->set_abstract(cnode->abstract());
1808 std::string fullname = cnode->fullname_with_scope();
1809 auto prim_input = cnode->input(kAnfPrimitiveIndex);
1810 // cnode is a call (partial/switch/switch_layer), full scope name is "1_2".
1811 // it is hard to analysis bug when it used as ge node name.
1812 if (!prim_input->isa<CNode>()) {
1813 new_cnode->set_fullname_with_scope(fullname);
1814 }
1815 new_cnode->set_scope(cnode->scope());
1816 if (!graph->is_dynamic_shape() && common::AnfAlgo::IsDynamicShape(new_cnode)) {
1817 graph->SetGraphDynamicAttr(true);
1818 }
1819 graph->FrontBackendMapAdd(node, new_cnode);
1820 SetReturnNode(new_cnode, graph);
1821 FlattenTuple(new_cnode);
1822 return true;
1823 }
1824
AddParameterToGraphInputs(const std::vector<AnfNodePtr> & parameters,KernelGraph * graph) const1825 void KernelGraphMgr::AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶meters, KernelGraph *graph) const {
1826 MS_EXCEPTION_IF_NULL(graph);
1827 auto graph_inputs = graph->MutableInputs();
1828 MS_EXCEPTION_IF_NULL(graph_inputs);
1829 graph_inputs->clear();
1830 for (auto ¶meter : parameters) {
1831 MS_EXCEPTION_IF_NULL(parameter);
1832 auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
1833 if (backend_parameter == nullptr) {
1834 // for example "def f(x,y,z) {return x + y}", parameter z in unused
1835 auto new_parameter = CreateNewParameter(parameter, graph);
1836 graph_inputs->push_back(new_parameter);
1837 graph->FrontBackendMapAdd(parameter, new_parameter);
1838 MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
1839 continue;
1840 }
1841 graph_inputs->push_back(backend_parameter);
1842 }
1843 }
1844
1845 // 1. Convert the node to make_tuple if the node is a ValueNode<ValueTuple> and it's the input of 'return' node.
1846 // 2. Set the return of graph if node is "Return" node.
1847 // 3. If the return of graph has a Partial Func, should inline it in return value.
SetReturnNode(const AnfNodePtr & node,KernelGraph * graph)1848 void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
1849 MS_EXCEPTION_IF_NULL(graph);
1850 MS_EXCEPTION_IF_NULL(node);
1851 if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
1852 return;
1853 }
1854 constexpr auto kReturnInputIdx = 1;
1855 auto return_node = node->cast<CNodePtr>();
1856 MS_EXCEPTION_IF_NULL(return_node);
1857 graph->set_return(return_node);
1858 auto graph_output = return_node->input(kReturnInputIdx);
1859 MS_EXCEPTION_IF_NULL(graph_output);
1860
1861 // If return's input is value node, then the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
1862 // match this pattern because that pass begin with output node but return node. So we add transform value tuple
1863 // to make_tuple here.
1864 if (common::AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
1865 return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
1866 }
1867
1868 // inline partial to call graph
1869 auto return_tuple = return_node->input(kReturnInputIdx);
1870 MS_EXCEPTION_IF_NULL(return_tuple);
1871 if (return_tuple->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(return_tuple, prim::kPrimMakeTuple)) {
1872 auto make_tuple = return_tuple->cast<CNodePtr>();
1873 MS_EXCEPTION_IF_NULL(make_tuple);
1874 size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
1875 // only support the last return node is a partial func now
1876 auto last_input_node = common::AnfAlgo::GetInputNode(make_tuple, tuple_input_num - 1);
1877 MS_EXCEPTION_IF_NULL(last_input_node);
1878 if (last_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(last_input_node, prim::kPrimPartial)) {
1879 size_t multi_tuple = 0;
1880 auto partial_node = last_input_node->cast<CNodePtr>();
1881 MS_EXCEPTION_IF_NULL(partial_node);
1882 size_t partial_input_num = common::AnfAlgo::GetInputTensorNum(partial_node);
1883 std::vector<AnfNodePtr> make_tuple_inputs;
1884 (void)make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1885 // skip last return node (is a partial)
1886 size_t param_begin = 0;
1887 for (size_t i = 0; i < tuple_input_num - 1; i++) {
1888 auto input = common::AnfAlgo::GetInputNode(make_tuple, i);
1889 MS_EXCEPTION_IF_NULL(input);
1890 auto node_abs = input->abstract();
1891 MS_EXCEPTION_IF_NULL(node_abs);
1892 if (node_abs->isa<abstract::AbstractSequence>()) {
1893 MS_EXCEPTION_IF_CHECK_FAIL(
1894 i == 0, "Input index: " + std::to_string(i) + " is a make tuple, input node: " + input->DebugString());
1895 MS_LOG(DEBUG) << "Flatten the make tuple, input node: " << input->DebugString()
1896 << ", output num: " << AnfUtils::GetOutputTensorNum(input);
1897 // flatten the make tuple
1898 for (size_t j = 0; j < AnfUtils::GetOutputTensorNum(input); j++) {
1899 auto idx = NewValueNode(SizeToLong(j));
1900 MS_EXCEPTION_IF_NULL(idx);
1901 auto imm = std::make_shared<Int64Imm>(j);
1902 idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
1903 auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
1904 std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(input, j)};
1905 auto shapes = {common::AnfAlgo::GetOutputInferShape(input, j)};
1906 common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
1907 param_begin++;
1908 multi_tuple++;
1909 (void)make_tuple_inputs.emplace_back(getitem);
1910 }
1911 } else {
1912 param_begin++;
1913 (void)make_tuple_inputs.emplace_back(input);
1914 }
1915 }
1916 // skip partial graph
1917 for (size_t i = kFirstIndex; i < partial_input_num; i++) {
1918 (void)make_tuple_inputs.emplace_back(common::AnfAlgo::GetInputNode(partial_node, i));
1919 }
1920 auto g_output = graph->NewCNode(make_tuple_inputs);
1921 MS_EXCEPTION_IF_NULL(g_output);
1922 std::vector<AbstractBasePtr> abstract_list;
1923 for (size_t i = kFirstIndex; i < make_tuple_inputs.size(); ++i) {
1924 auto inputs_node = make_tuple_inputs[i];
1925 MS_EXCEPTION_IF_NULL(inputs_node);
1926 (void)abstract_list.emplace_back(inputs_node->abstract());
1927 }
1928 auto abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
1929 MS_EXCEPTION_IF_NULL(g_output);
1930 g_output->set_abstract(abstract);
1931 graph->set_output(g_output);
1932 kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0), param_begin,
1933 common::AnfAlgo::GetInputTensorNum(g_output), multi_tuple};
1934 }
1935 }
1936 }
1937
ConstructKernelGraph(const AnfNodePtrList & lst,const AnfNodePtrList & outputs,DeviceType device_target,bool common_opt,bool is_enable_zero_copy)1938 KernelGraphPtr KernelGraphMgr::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
1939 DeviceType device_target, bool common_opt,
1940 bool is_enable_zero_copy) {
1941 mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
1942 std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
1943 auto graph = NewKernelGraph();
1944 MS_EXCEPTION_IF_NULL(graph);
1945 // Set the zero copy flag in subgraph sink mode.
1946 if (is_enable_zero_copy) {
1947 MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
1948 graph->set_flag(kFlagEnableZeroCopyInGraph, true);
1949 }
1950 MS_LOG(INFO) << "Create graph: " << graph->graph_id();
1951 graph->set_device_target(device_target);
1952 for (const auto &node : lst) {
1953 MS_EXCEPTION_IF_NULL(node);
1954 MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
1955 if (!node->isa<CNode>()) {
1956 MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
1957 }
1958 auto cnode = node->cast<CNodePtr>();
1959 MS_EXCEPTION_IF_NULL(cnode);
1960 // create a new cnode object
1961 auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
1962 MS_EXCEPTION_IF_NULL(new_cnode);
1963 if (IsOneOfPrimitiveCNode(new_cnode, {prim::kPrimCall, prim::kPrimPartial})) {
1964 auto fn = new_cnode->input(kIndexOne);
1965 MS_EXCEPTION_IF_NULL(fn);
1966 auto child_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(fn);
1967 MS_EXCEPTION_IF_NULL(child_kernel_graph);
1968 child_graph_order.push_back(std::weak_ptr<KernelGraph>(child_kernel_graph));
1969 }
1970
1971 new_cnode->set_abstract(cnode->abstract());
1972 new_cnode->set_scope(cnode->scope());
1973 new_cnode->set_attrs(cnode->attrs());
1974
1975 if (new_cnode->HasAttr(kAttrReplaceRealKernelInBackend)) {
1976 MS_LOG(DEBUG) << "Erase flag for node: " << new_cnode->DebugString();
1977 new_cnode->EraseAttr(kAttrReplaceRealKernelInBackend);
1978 }
1979
1980 if (cnode->user_data<pynative::JitCallGraph>()) {
1981 new_cnode->set_user_data(cnode->user_data<pynative::JitCallGraph>());
1982 }
1983 // record map relations between anf from ME and new anf node used in backend
1984 graph->FrontBackendMapAdd(node, new_cnode);
1985 if (!graph->is_dynamic_shape() && common::AnfAlgo::IsDynamicShape(new_cnode)) {
1986 graph->SetGraphDynamicAttr(true);
1987 }
1988 }
1989 // add a make_tuple at the end of graph as output
1990 graph->set_child_graph_order(child_graph_order);
1991 graph->set_output(ConstructOutput(outputs, graph));
1992 graph->SetExecOrderByDefault();
1993
1994 #ifndef ENABLE_SECURITY
1995 if (ExistSummaryNode(graph.get())) {
1996 graph->set_summary_node_exist(true);
1997 }
1998 #endif
1999 MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
2000 if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
2001 FuncGraphManagerPtr manager = MakeManager({graph}, false);
2002 if (manager) {
2003 manager->AddFuncGraph(graph);
2004 graph->set_manager(manager);
2005 }
2006 UnifyMindIR(graph);
2007 graph->UpdateGraphAquireGilAttr();
2008 if (common_opt) {
2009 opt::BackendCommonOptimization(graph);
2010 }
2011 graph->SetInputNodes();
2012 SetInputNodeUsage(graph, manager);
2013 graph->SetOptimizerFlag();
2014 }
2015 graph->set_parameters(graph->inputs());
2016 return graph;
2017 }
2018
ConstructKernelGraph(const FuncGraphPtr & func_graph,std::vector<KernelGraphPtr> * all_out_graph,DeviceType device_target)2019 std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGraphPtr &func_graph,
2020 std::vector<KernelGraphPtr> *all_out_graph,
2021 DeviceType device_target) {
2022 auto graph = NewKernelGraph();
2023 front_backend_graph_map_[func_graph.get()] = graph;
2024 ConstructKernelGraphInner(func_graph, all_out_graph, device_target, graph);
2025 return graph;
2026 }
2027
ConstructPackKernelGraph(const FuncGraphPtr & func_graph,std::vector<KernelGraphPtr> * all_out_graph,DeviceType device_target)2028 std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructPackKernelGraph(const FuncGraphPtr &func_graph,
2029 std::vector<KernelGraphPtr> *all_out_graph,
2030 DeviceType device_target) {
2031 auto graph = NewPynativeKernelGraph();
2032 ConstructKernelGraphInner(func_graph, all_out_graph, device_target, graph);
2033 return graph;
2034 }
2035
ConstructKernelGraphInner(const FuncGraphPtr & func_graph,std::vector<KernelGraphPtr> * all_out_graph,DeviceType device_target,const KernelGraphPtr & graph)2036 void KernelGraphMgr::ConstructKernelGraphInner(const FuncGraphPtr &func_graph,
2037 std::vector<KernelGraphPtr> *all_out_graph, DeviceType device_target,
2038 const KernelGraphPtr &graph) {
2039 MS_EXCEPTION_IF_NULL(func_graph);
2040 MS_EXCEPTION_IF_NULL(all_out_graph);
2041 auto node_list = TopoSort(func_graph->get_return());
2042 MS_EXCEPTION_IF_NULL(graph);
2043 auto context = MsContext::GetInstance();
2044 MS_EXCEPTION_IF_NULL(context);
2045 if (func_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline) {
2046 MS_LOG(INFO) << "Need backend inline: " << graph->graph_id();
2047 graph->set_need_inline(true);
2048 }
2049 MS_LOG(INFO) << "Create graph: " << graph->graph_id();
2050 graph->set_device_target(device_target);
2051 // Create parameter
2052 for (const auto &node : func_graph->parameters()) {
2053 MS_EXCEPTION_IF_NULL(node);
2054 MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
2055 auto graph_inputs = graph->MutableInputs();
2056 MS_EXCEPTION_IF_NULL(graph_inputs);
2057 auto new_parameter = CreateNewParameter(node, graph.get());
2058 graph_inputs->push_back(new_parameter);
2059 graph->FrontBackendMapAdd(node, new_parameter);
2060 }
2061
2062 std::vector<ParameterPtr> added_parameters;
2063 std::vector<std::weak_ptr<KernelGraph>> child_kernel_graphs;
2064 for (const auto &node : node_list) {
2065 MS_EXCEPTION_IF_NULL(node);
2066 if (node->isa<Parameter>()) {
2067 continue;
2068 }
2069 MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
2070 // Create value node
2071 if (node->isa<ValueNode>()) {
2072 if (NeedConvertValueNodeToParameter(node)) {
2073 ConvertValueNodeToParameter(graph, node, &added_parameters);
2074 continue;
2075 }
2076 // Create common value node
2077 if (!IsValueNode<FuncGraph>(node)) {
2078 (void)CreateNewValueNode(node, graph.get());
2079 continue;
2080 }
2081 // Create child kernel graph according ValueNode<FuncGraph>
2082 FuncGraphPtr child_graph = common::AnfAlgo::GetValueNodeFuncGraph(node);
2083 auto child_kernel_graph = front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()
2084 ? ConstructKernelGraph(child_graph, all_out_graph, device_target)
2085 : front_backend_graph_map_[child_graph.get()];
2086 (void)child_kernel_graphs.emplace_back(std::weak_ptr<KernelGraph>(child_kernel_graph));
2087 (void)CreateValueNodeKernelGraph(node, graph.get());
2088 continue;
2089 }
2090 // Create cnode
2091 if (!CreateCNodeOfKernelGraph(node, graph.get())) {
2092 #ifdef ENABLE_DUMP_IR
2093 DumpIR("construct_kernel_graph_fail.ir", func_graph);
2094 #endif
2095 MS_LOG(EXCEPTION) << "Construct func graph " << func_graph->ToString() << " failed."
2096 << trace::DumpSourceLines(node);
2097 }
2098 }
2099
2100 AddParameterToGraphInputs(func_graph->parameters(), graph.get());
2101 // Add ValueNode-Parameter to graph.
2102 auto graph_inputs = graph->MutableInputs();
2103 MS_EXCEPTION_IF_NULL(graph_inputs);
2104 for (auto ¶meter : added_parameters) {
2105 (void)graph_inputs->emplace_back(parameter);
2106 }
2107
2108 FuncGraphManagerPtr manager = MakeManager({graph});
2109 graph->SetInputNodes();
2110 SetInputNodeUsage(graph, manager);
2111 graph->SetExecOrderByDefault();
2112
2113 #ifndef ENABLE_SECURITY
2114 if (ExistSummaryNode(graph.get())) {
2115 graph->set_summary_node_exist(true);
2116 }
2117 #endif
2118
2119 all_out_graph->push_back(graph);
2120 graph->set_parameters(graph->inputs());
2121 graph->set_child_graph_order(child_kernel_graphs);
2122 }
2123
HandleGraphInputsOutputs(const nlohmann::json & graph_json,KernelGraph * graph)2124 void HandleGraphInputsOutputs(const nlohmann::json &graph_json, KernelGraph *graph) {
2125 MS_EXCEPTION_IF_NULL(graph);
2126 auto &context = CompileCacheContext::GetInstance();
2127 if (graph_json.contains(kInputs)) {
2128 const auto &inputs_json = graph_json[kInputs];
2129 auto mutable_inputs = graph->MutableInputs();
2130 for (const auto &name : inputs_json) {
2131 AnfNodePtr node = context.FindBackNodeByBackName(name);
2132 MS_EXCEPTION_IF_NULL(node);
2133 context.InsertBackNameToBackNode(name, node);
2134 mutable_inputs->push_back(node);
2135 }
2136 }
2137 if (graph_json.contains(kParameters)) {
2138 const auto ¶meters_json = graph_json[kParameters];
2139 std::vector<AnfNodePtr> parameters;
2140 for (const auto &name : parameters_json) {
2141 auto node = context.FindBackNodeByBackName(name);
2142 MS_EXCEPTION_IF_NULL(node);
2143 parameters.push_back(node);
2144 }
2145 graph->set_parameters(parameters);
2146 }
2147
2148 if (graph_json.contains(kValidInputs)) {
2149 const auto &valid_inputs_json = graph_json[kValidInputs];
2150 auto mutable_valid_inputs = graph->MutableValidInputs();
2151 std::vector<bool> valid_inputs;
2152 (void)(std::transform(valid_inputs_json.begin(), valid_inputs_json.end(), std::back_inserter(valid_inputs),
2153 [](const auto &val) { return val; }));
2154 *mutable_valid_inputs = valid_inputs;
2155 }
2156 const auto &front_graph_name = graph_json[kCorrespondFrontendGraph];
2157 const auto &front_graph_node = context.FindFrontNodeByFrontName(front_graph_name);
2158 FuncGraphPtr front_graph = GetValueNode<FuncGraphPtr>(front_graph_node);
2159 MS_EXCEPTION_IF_NULL(front_graph);
2160 graph->FrontBackendMapAdd(front_graph->get_return(), graph->get_return());
2161 }
2162
HandleGraphSimpleAttr(const nlohmann::json & graph_json,KernelGraph * graph)2163 void HandleGraphSimpleAttr(const nlohmann::json &graph_json, KernelGraph *graph) {
2164 MS_EXCEPTION_IF_NULL(graph);
2165 MS_LOG(INFO) << "Handle graph " << graph->ToString() << " simple attr.";
2166 auto &context = CompileCacheContext::GetInstance();
2167 graph->set_run_mode(graph_json[kRunMode]);
2168 graph->set_is_loop_count_sink(graph_json[kIsLoopCountSink]);
2169 graph->SetGraphDynamicAttr(graph_json[kIsDynamicShape]);
2170 graph->set_device_target(graph_json[kDeviceTarget]);
2171 graph->set_root_graph_id(graph_json[kRootGraphId]);
2172 graph->set_executable(graph_json[kExecutable]);
2173 graph->set_recursive_call(graph_json[kHasRecursiveCall]);
2174 graph->set_need_inline(graph_json[kNeedInline]);
2175 graph->set_is_need_gil(graph_json[kIsNeedGil]);
2176 graph->set_is_from_single_op(graph_json[kIsFromSingleOp]);
2177 graph->set_subgraph_multi_call(graph_json[kHasSubgraphMultiCall]);
2178 graph->set_label_num(graph_json[kLabelNum]);
2179 #ifndef ENABLE_SECURITY
2180 // set summary_node of graph
2181 graph->set_summary_node_exist(graph_json[kSummaryNodeExist]);
2182 #endif
2183 if (graph_json.contains(kStartLabel)) {
2184 auto start_label = context.FindBackNodeByBackName(graph_json[kStartLabel]);
2185 if (start_label) {
2186 auto cstart_label = start_label->cast<CNodePtr>();
2187 MS_EXCEPTION_IF_NULL(cstart_label);
2188 graph->set_start_label(cstart_label);
2189 }
2190 }
2191 if (graph_json.contains(kEndGoto)) {
2192 auto end_goto = context.FindBackNodeByBackName(graph_json[kEndGoto]);
2193 if (end_goto) {
2194 auto cend_goto = end_goto->cast<CNodePtr>();
2195 MS_EXCEPTION_IF_NULL(cend_goto);
2196 graph->set_end_goto(cend_goto);
2197 }
2198 }
2199 if (graph_json.contains(kParameterUniqueNameToName)) {
2200 const auto &unique_name_to_name_json = graph_json[kParameterUniqueNameToName];
2201 for (const auto &[unique_name, name] : unique_name_to_name_json.items()) {
2202 auto node = context.FindBackNodeByBackName(unique_name);
2203 MS_EXCEPTION_IF_NULL(node);
2204 auto param = node->cast<ParameterPtr>();
2205 MS_EXCEPTION_IF_NULL(param);
2206 param->set_name(name);
2207 }
2208 }
2209 MS_LOG(INFO) << "Handle graph " << graph->ToString() << " simple attr success.";
2210 }
2211
HandleAttrAboutOtherGraph(const mindspore::HashMap<GraphId,std::shared_ptr<KernelGraph>> & graphs,const nlohmann::json & graph_json,KernelGraph * graph)2212 void HandleAttrAboutOtherGraph(const mindspore::HashMap<GraphId, std::shared_ptr<KernelGraph>> &graphs,
2213 const nlohmann::json &graph_json, KernelGraph *graph) {
2214 MS_EXCEPTION_IF_NULL(graph);
2215 auto &context = CompileCacheContext::GetInstance();
2216 if (graph_json.contains(kPreGraphs)) {
2217 const auto &pre_graphs_json = graph_json[kPreGraphs];
2218 for (const auto &iter : pre_graphs_json) {
2219 auto pre_graph = graphs.at(iter);
2220 MS_EXCEPTION_IF_NULL(pre_graph);
2221 graph->AddPreGraph(pre_graph);
2222 }
2223 }
2224 if (graph_json.contains(kPostGraphs)) {
2225 const auto &post_graphs_json = graph_json[kPostGraphs];
2226 for (const auto &iter : post_graphs_json) {
2227 auto post_graph = graphs.at(iter);
2228 MS_EXCEPTION_IF_NULL(post_graph);
2229 graph->AddPostGraph(post_graph);
2230 }
2231 }
2232 if (graph_json.contains(kChildGraphResult)) {
2233 const auto &child_graph_result_json = graph_json[kChildGraphResult];
2234 for (const auto &iter : child_graph_result_json) {
2235 auto node = context.FindBackNodeByBackName(iter);
2236 MS_EXCEPTION_IF_NULL(node);
2237 graph->AddChildGraphResult(node);
2238 }
2239 }
2240 if (graph_json.contains(kChildGraphOrder)) {
2241 const auto &child_graph_order_json = graph_json[kChildGraphOrder];
2242 std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
2243 for (const auto &iter : child_graph_order_json) {
2244 auto child_graph = graphs.at(iter);
2245 MS_EXCEPTION_IF_NULL(child_graph);
2246 child_graph_order.push_back(std::weak_ptr<KernelGraph>(child_graph));
2247 }
2248 graph->set_child_graph_order(child_graph_order);
2249 }
2250 }
2251
HandleGraphComplexAttr(const mindspore::HashMap<GraphId,std::shared_ptr<KernelGraph>> & graphs,const nlohmann::json & graph_json,KernelGraph * graph)2252 void HandleGraphComplexAttr(const mindspore::HashMap<GraphId, std::shared_ptr<KernelGraph>> &graphs,
2253 const nlohmann::json &graph_json, KernelGraph *graph) {
2254 MS_EXCEPTION_IF_NULL(graph);
2255 MS_LOG(INFO) << "Handle graph " << graph->ToString() << " complex attr.";
2256 auto &context = CompileCacheContext::GetInstance();
2257 std::vector<CNodePtr> execution_order;
2258 const auto &execution_order_json = graph_json[kExecutionOrder];
2259 for (const auto &order : execution_order_json) {
2260 auto node = context.FindBackNodeByBackName(order);
2261 MS_EXCEPTION_IF_NULL(node);
2262 execution_order.push_back(node->cast<CNodePtr>());
2263 }
2264 graph->set_execution_order(execution_order);
2265 if (graph_json.contains(kCommSubGraphIds)) {
2266 const auto &comm_sub_grpah_ids_json = graph_json[kCommSubGraphIds];
2267 for (const auto &iter : comm_sub_grpah_ids_json) {
2268 graph->RecordNewCommSubGraphId(iter);
2269 }
2270 }
2271 HandleAttrAboutOtherGraph(graphs, graph_json, graph);
2272 if (graph_json.contains(kInternalParameterToFrontNode)) {
2273 const auto &internal_parameter_to_front_node_json = graph_json[kInternalParameterToFrontNode];
2274 HashMap<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node;
2275 for (const auto &iter : internal_parameter_to_front_node_json) {
2276 const auto &back_name = iter.at(0);
2277 const auto &front_name = iter.at(kIndexOne);
2278 const auto &index = iter.at(kIndexTwo);
2279 auto back_node = context.FindBackNodeByBackName(back_name);
2280 MS_EXCEPTION_IF_NULL(back_node);
2281 auto front_node = context.FindFrontNodeByFrontName(front_name);
2282 MS_EXCEPTION_IF_NULL(front_node);
2283 internal_parameter_to_front_node[back_node] = AnfWithOutIndex(front_node, index);
2284 }
2285 graph->SetInternalParameterToFrontNodeMap(internal_parameter_to_front_node);
2286 }
2287 if (graph_json.contains(kRefInOutMap)) {
2288 const auto &ref_in_out_map_json = graph_json[kRefInOutMap];
2289 for (const auto &iter : ref_in_out_map_json) {
2290 const auto &first_name = iter.at(0);
2291 const auto &first_index = iter.at(kIndexOne);
2292 const auto &second_name = iter.at(kIndexTwo);
2293 const auto &second_index = iter.at(kIndexThree);
2294 auto first_node = context.FindBackNodeByBackName(first_name);
2295 MS_EXCEPTION_IF_NULL(first_node);
2296 auto second_node = context.FindBackNodeByBackName(second_name);
2297 MS_EXCEPTION_IF_NULL(second_node);
2298 graph->AddRefCorrespondPairs(AnfWithOutIndex(first_node, first_index),
2299 AnfWithOutIndex(second_node, second_index));
2300 }
2301 }
2302 if (graph_json.contains(kNodesKernelInfo)) {
2303 const auto &kernel_infos_json = graph_json[kNodesKernelInfo];
2304 LoadAnfKernelInfoFromJson(kernel_infos_json);
2305 }
2306 if (graph_json.contains(kGraphValueNodes)) {
2307 const auto &graph_value_nodes_json = graph_json[kGraphValueNodes];
2308 for (const auto &iter : graph_value_nodes_json) {
2309 auto node = context.FindBackNodeByBackName(iter);
2310 MS_EXCEPTION_IF_NULL(node);
2311 auto value_node = node->cast<ValueNodePtr>();
2312 MS_EXCEPTION_IF_NULL(value_node);
2313 graph->AddValueNodeToGraph(value_node);
2314 }
2315 }
2316 #ifndef ENABLE_SECURITY
2317 if (graph_json.contains(kSummaryNodes)) {
2318 const auto &summary_nodes_json = graph_json[kSummaryNodes];
2319 std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes;
2320 for (const auto &iter : summary_nodes_json) {
2321 const auto &first = iter.at(0);
2322 const auto &name = iter.at(kIndexOne);
2323 auto node = context.FindBackNodeByBackName(name);
2324 MS_EXCEPTION_IF_NULL(node);
2325 const auto &index = iter.at(kIndexTwo);
2326 summary_nodes[first] = std::make_pair(node, index);
2327 }
2328 graph->set_summary_nodes(summary_nodes);
2329 }
2330 #endif
2331 MS_LOG(INFO) << "Handle graph " << graph->ToString() << " complex attr success.";
2332 }
2333
ParseKernelGraphNodesAndAttrs(const nlohmann::json & model_json)2334 bool KernelGraphMgr::ParseKernelGraphNodesAndAttrs(const nlohmann::json &model_json) {
2335 auto &context = CompileCacheContext::GetInstance();
2336 for (auto &[graph_name, graph_json] : model_json.items()) {
2337 MS_LOG(DEBUG) << "Parse graph " << graph_name << " nodes and attrs.";
2338 KernelGraphPtr graph = graphs_.at(graph_json[kGraphId]);
2339 MS_EXCEPTION_IF_NULL(graph);
2340 HandleGraphInputsOutputs(graph_json, graph.get());
2341 const auto &back_to_front = graph_json[kBackendFrontAnf];
2342 for (const auto &[back_node_name, front_node_name] : back_to_front.items()) {
2343 auto back_node = context.FindBackNodeByBackName(back_node_name);
2344 if (!back_node) {
2345 MS_LOG(EXCEPTION) << "The backend node is nullptr, its unique name is " << back_node_name;
2346 }
2347 auto front_node = context.FindFrontNodeByFrontName(front_node_name);
2348 if (!front_node) {
2349 MS_LOG(EXCEPTION) << "The frontend node is nullptr, its unique name is " << front_node_name;
2350 }
2351 if (graph->FrontendNodeExistInFrontBackendMap(front_node)) {
2352 if (graph->BackendNodeExistInFrontBackendMap(back_node)) {
2353 continue;
2354 }
2355 auto old_back_node = graph->GetBackendAnfByFrontAnf(front_node);
2356 graph->FrontBackendMapAdd(old_back_node, back_node);
2357 } else {
2358 graph->FrontBackendMapAdd(front_node, back_node);
2359 }
2360 }
2361 HandleGraphSimpleAttr(graph_json, graph.get());
2362 HandleGraphComplexAttr(graphs_, graph_json, graph.get());
2363
2364 FuncGraphManagerPtr manager = MakeManager({graph});
2365 if (manager) {
2366 manager->AddFuncGraph(graph);
2367 graph->set_manager(manager);
2368 }
2369 graph->SetInputNodes();
2370 SetInputNodeUsage(graph, manager);
2371 graph->SetOptimizerFlag();
2372 }
2373 return true;
2374 }
2375
ResetGetNextSharedName(const FuncGraphPtr & graph)2376 void ResetGetNextSharedName(const FuncGraphPtr &graph) {
2377 auto &config_mgr = ConfigManager::GetInstance();
2378 auto queue_name = config_mgr.QueueName();
2379 auto cnodes = graph->GetOrderedCnodes();
2380 for (const auto &cnode : cnodes) {
2381 auto prim = GetValuePtr<Primitive>(cnode->input(0));
2382 if (prim != nullptr && prim->HasAttr("shared_name")) {
2383 prim->set_attr("shared_name", MakeValue(queue_name));
2384 break;
2385 }
2386 }
2387 }
2388
ConstructKernelGraph(std::vector<KernelGraphPtr> * all_out_graph)2389 std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(std::vector<KernelGraphPtr> *all_out_graph) {
2390 MS_LOG(WARNING) << "Use the compile cache to construct kernel graph, Be aware of correctness risks.";
2391 auto &context = CompileCacheContext::GetInstance();
2392 auto frontend_graph = context.FrontGraph();
2393 if (!frontend_graph) {
2394 MS_LOG(EXCEPTION) << "The frontend graph is null";
2395 }
2396 auto cache_path = context.GetBackendGraphCachePath(frontend_graph);
2397 std::string json_path = cache_path + kJsonSuffix;
2398 nlohmann::json model_json;
2399 auto load_json_success = LoadJson(json_path, &model_json);
2400 if (!load_json_success) {
2401 MS_LOG(EXCEPTION) << "Load json file " << json_path << " failed.";
2402 }
2403 // construct kernel graph and its params that exist correspond frontend param
2404 mindspore::HashMap<std::string, AnfNodePtr> name_to_node;
2405 (void)std::for_each(name_to_params_.begin(), name_to_params_.end(), [&name_to_node](const auto &ele) {
2406 if (!ele.second.expired()) {
2407 name_to_node[ele.first] = ele.second.lock();
2408 }
2409 });
2410 MS_LOG(DEBUG) << "Construct kernel graph and its params that exist correspond frontend param.";
2411 for (size_t i = 0; i < model_json.size(); i++) {
2412 auto kernel_graph = NewKernelGraph();
2413 all_out_graph->push_back(kernel_graph);
2414 const auto &graph_name = kernel_graph->ToString();
2415 if (!model_json.contains(graph_name)) {
2416 MS_LOG(EXCEPTION) << "Load graph " << graph_name << " from json failed.";
2417 }
2418 auto &graph_json = model_json[graph_name];
2419 if (!graph_json.contains(kCorrespondFrontendGraph)) {
2420 continue;
2421 }
2422 const auto &front_graph_name = graph_json[kCorrespondFrontendGraph];
2423 const auto &front_graph_node = context.FindFrontNodeByFrontName(front_graph_name);
2424 FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(front_graph_node);
2425 MS_EXCEPTION_IF_NULL(fg);
2426 front_backend_graph_map_[fg.get()] = kernel_graph;
2427 if (graph_json.contains(kBackendParamToFrontendParamIndex)) {
2428 const auto &backend_param_to_frontend_param_index = graph_json[kBackendParamToFrontendParamIndex];
2429 const auto &frontend_graph_params = fg->parameters();
2430 for (const auto &[param_unique_name, index] : backend_param_to_frontend_param_index.items()) {
2431 const auto front_param = frontend_graph_params.at(index);
2432 MS_EXCEPTION_IF_NULL(front_param);
2433 MS_LOG(DEBUG) << "Start create new node, old node = " << front_param->DebugString()
2434 << ", new node = " << param_unique_name;
2435 auto new_parameter = CreateNewParameter(front_param, kernel_graph.get());
2436 kernel_graph->FrontBackendMapAdd(front_param, new_parameter);
2437 name_to_node[param_unique_name] = new_parameter;
2438 }
2439 }
2440 }
2441
2442 std::vector<FuncGraphPtr> graphs_for_load;
2443 (void)(std::transform(all_out_graph->begin(), all_out_graph->end(), std::back_inserter(graphs_for_load),
2444 [](const KernelGraphPtr &g) { return g; }));
2445
2446 MindIRLoader mindir_loader;
2447 std::string mindir_path = cache_path + kMindIrSuffix;
2448 auto real_path = Common::CreatePrefixPath(mindir_path, true);
2449 if (!CheckPath(real_path)) {
2450 MS_LOG(EXCEPTION) << "The mindir path is " << mindir_path << ", and it is a invalid path!";
2451 }
2452 if (!mindir_loader.LoadMindIR(real_path.value(), graphs_for_load, &name_to_node)) {
2453 MS_LOG(EXCEPTION) << "Load mindir from " << real_path.value() << " failed.";
2454 }
2455 (void)std::for_each(name_to_node.begin(), name_to_node.end(), [](const auto &ele) {
2456 auto node = ele.second;
2457 MS_EXCEPTION_IF_NULL(node);
2458 if (node->template isa<Parameter>()) {
2459 name_to_params_[ele.first] = std::weak_ptr<AnfNode>(node);
2460 }
2461 });
2462 context.SetBackNameToBackNode(name_to_node);
2463 // the value of attr "shared_name" will changed every time, so reset GetNext shared_name
2464 ResetGetNextSharedName(all_out_graph->front());
2465 if (!ParseKernelGraphNodesAndAttrs(model_json)) {
2466 MS_LOG(EXCEPTION) << "Parse kernel graph nodes and attrs failed.";
2467 }
2468
2469 #ifdef ENABLE_DUMP_IR
2470 auto ms_context = MsContext::GetInstance();
2471 MS_EXCEPTION_IF_NULL(ms_context);
2472 if (ms_context->CanDump(kIntroductory)) {
2473 for (const auto &iter : graphs_) {
2474 auto dump_name = std::string("loaded_") + iter.second->ToString() + ".ir";
2475 DumpIR(dump_name, iter.second);
2476 }
2477 }
2478 #endif
2479 MS_LOG(WARNING)
2480 << "Use the compile cache to construct kernel graph success, and will execute the preprocess before run directly.";
2481 return all_out_graph->front();
2482 }
2483
SetInputNodeUsage(const KernelGraphPtr & graph,const FuncGraphManagerPtr & manager) const2484 void KernelGraphMgr::SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager) const {
2485 MS_EXCEPTION_IF_NULL(graph);
2486 MS_EXCEPTION_IF_NULL(manager);
2487 auto input_nodes = graph->input_nodes();
2488 for (auto &input_node : input_nodes) {
2489 MS_EXCEPTION_IF_NULL(input_node);
2490 if (input_node->isa<Parameter>()) {
2491 auto node_ptr = input_node->cast<ParameterPtr>();
2492 MS_EXCEPTION_IF_NULL(node_ptr);
2493 if (!IsUsedByRealKernel(manager, input_node, graph->graph_id())) {
2494 node_ptr->SetNotUsedByRealKernelInGraph(graph->graph_id());
2495 }
2496 auto shape = node_ptr->Shape();
2497 MS_EXCEPTION_IF_NULL(shape);
2498 if (shape->isa<abstract::Shape>() && shape->IsDynamic()) {
2499 node_ptr->set_has_dynamic_shape(true);
2500 }
2501 if (input_node->abstract() != nullptr && input_node->abstract()->isa<abstract::AbstractSequence>()) {
2502 // If the parameter is dynamic sequence, it is regard as dynamic shape.
2503 const auto &tuple_abs = input_node->abstract()->cast<abstract::AbstractSequencePtr>();
2504 MS_EXCEPTION_IF_NULL(tuple_abs);
2505 if (tuple_abs->dynamic_len()) {
2506 MS_LOG(INFO) << "Input node:" << input_node->DebugString() << " set dynamic flag to true";
2507 node_ptr->set_has_dynamic_shape(true);
2508 }
2509 }
2510 }
2511 }
2512 }
2513
2514 namespace {
CNodeFirstInputIsPrimitive(const AnfNodePtr & node)2515 bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
2516 if (node == nullptr) {
2517 return false;
2518 }
2519 auto cnode = node->cast<CNodePtr>();
2520 if (cnode == nullptr) {
2521 return false;
2522 }
2523 auto prim = cnode->input(kAnfPrimitiveIndex);
2524 if (prim == nullptr || !IsValueNode<Primitive>(prim)) {
2525 return false;
2526 }
2527 return true;
2528 }
2529
ExtendNodeUsers(const FuncGraphManagerPtr & front_func_graph_manager,const AnfNodePtr & front_node)2530 std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
2531 const AnfNodePtr &front_node) {
2532 MS_EXCEPTION_IF_NULL(front_func_graph_manager);
2533 auto &users = front_func_graph_manager->node_users()[front_node];
2534 std::vector<AnfNodePtr> result;
2535 for (auto &user : users) {
2536 if (common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend) ||
2537 common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimLoad)) {
2538 auto depend_cnode = user.first->cast<CNodePtr>();
2539 if (depend_cnode == nullptr) {
2540 continue;
2541 }
2542 if (front_node != depend_cnode->input(1)) {
2543 continue;
2544 }
2545 auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
2546 (void)result.insert(result.cend(), res.cbegin(), res.cend());
2547 } else if (common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
2548 auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
2549 (void)result.insert(result.cend(), res.cbegin(), res.cend());
2550 } else {
2551 (void)result.emplace_back(user.first);
2552 }
2553 }
2554 return result;
2555 }
2556
// Returns the node that can serve as an internal output for `front_node`, or nullptr when there is none:
//   - a non-CNode yields nullptr;
//   - a real kernel or a TupleGetItem is returned as is;
//   - a MakeTuple with more than one input is resolved through its input 1;
//   - a Depend with at least kDependInputSize inputs is resolved through its input kRealInputIndexInDepend;
//   - everything else yields nullptr.
// Raises an exception when `front_node` is null.
AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
  MS_EXCEPTION_IF_NULL(front_node);
  if (!front_node->isa<CNode>()) {
    return nullptr;
  }
  if (AnfUtils::IsRealKernel(front_node) ||
      common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
    return front_node;
  }
  if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimMakeTuple)) {
    auto make_tuple = front_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    auto &tuple_inputs = make_tuple->inputs();
    if (tuple_inputs.size() > 1) {
      return GetSupportedInternalNode(tuple_inputs[1]);
    }
  }
  if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
    auto depend = front_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend);
    auto &depend_inputs = depend->inputs();
    if (depend_inputs.size() >= kDependInputSize) {
      return GetSupportedInternalNode(depend_inputs[kRealInputIndexInDepend]);
    }
  }
  return nullptr;
}
2586
IsUnusedInternlOutput(const AnfNodePtr & user)2587 bool IsUnusedInternlOutput(const AnfNodePtr &user) {
2588 if (!CNodeFirstInputIsPrimitive(user)) {
2589 return true;
2590 }
2591 if (IsPrimitiveCNode(user, prim::kPrimSwitch) || IsPrimitiveCNode(user, prim::kPrimSwitchLayer)) {
2592 return true;
2593 }
2594 if (!AnfUtils::IsRealKernel(user)) {
2595 return true;
2596 }
2597 return false;
2598 }
2599 } // namespace
2600
2601 constexpr auto kMixTarget = "MixTarget";
2602 constexpr auto kNoTarget = "NoTarget";
AddPartialParametersMap(const AnfNodePtr & partial_node)2603 std::string KernelGraphMgr::AddPartialParametersMap(const AnfNodePtr &partial_node) {
2604 MS_EXCEPTION_IF_NULL(partial_node);
2605 auto iter = partial_target_map_.find(partial_node);
2606 if (iter != partial_target_map_.end()) {
2607 return iter->second;
2608 }
2609 auto partial_cnode = partial_node->cast<CNodePtr>();
2610 MS_EXCEPTION_IF_NULL(partial_cnode);
2611 auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
2612 // If graph is nullptr, it means that the funcgraph in the partial node is a deadnode, and the processing is skipped.
2613 if (partial_graph == nullptr) {
2614 return kNoTarget;
2615 }
2616 auto parameters = partial_graph->parameters();
2617 auto partial_inputs = partial_cnode->inputs();
2618 const size_t kNonParameterNum = 2;
2619 if (parameters.size() + kNonParameterNum != partial_inputs.size()) {
2620 return kMixTarget;
2621 }
2622 for (size_t i = 0; i < parameters.size(); ++i) {
2623 partial_parameters_map_[parameters[i]] = partial_inputs[kNonParameterNum + i];
2624 }
2625 auto graph_nodes = TopoSort(partial_graph->get_return());
2626 std::string graph_target = kNoTarget;
2627 for (auto &node : graph_nodes) {
2628 if (!node->isa<CNode>()) {
2629 continue;
2630 }
2631 if (!AnfUtils::IsRealKernel(node)) {
2632 continue;
2633 }
2634 std::string cur_target = GetCNodeTarget(node);
2635 if (graph_target == kNoTarget) {
2636 graph_target = cur_target;
2637 }
2638 if (graph_target != cur_target) {
2639 graph_target = kMixTarget;
2640 break;
2641 }
2642 }
2643 (void)partial_target_map_.emplace(std::pair<AnfNodePtr, std::string>(partial_node, graph_target));
2644 return graph_target;
2645 }
2646
2647 namespace {
IsNeedAddPartialParameter(const AnfNodePtr & user,const std::string & kernel_target,const std::shared_ptr<KernelGraph> & graph)2648 bool IsNeedAddPartialParameter(const AnfNodePtr &user, const std::string &kernel_target,
2649 const std::shared_ptr<KernelGraph> &graph) {
2650 // If the flag is enable, it means the graph would run in subgraph sink mode, the real parameter on partial
2651 // cannot share the same device address with the formal parameter.
2652 MS_EXCEPTION_IF_NULL(graph);
2653 return common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
2654 !ExistGraphCaller(user) && (!graph->has_flag(kFlagEnableZeroCopyInGraph));
2655 }
2656 } // namespace
2657
HandleInternalOutput(const AnfNodePtr & input_front_node,const AnfNodePtr & backend_node,const FuncGraphManagerPtr & front_func_graph_manager,const std::shared_ptr<KernelGraph> & backend_graph)2658 void KernelGraphMgr::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
2659 const FuncGraphManagerPtr &front_func_graph_manager,
2660 const std::shared_ptr<KernelGraph> &backend_graph) {
2661 MS_EXCEPTION_IF_NULL(backend_graph);
2662 auto front_node = GetSupportedInternalNode(input_front_node);
2663 if (front_node == nullptr) {
2664 return;
2665 }
2666 auto front_real_kernel_pair = common::AnfAlgo::VisitKernel(front_node, 0);
2667 auto backend_real_kernel_pair = common::AnfAlgo::VisitKernel(backend_node, 0);
2668 auto backend_real_kernel = backend_real_kernel_pair.first;
2669 if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) {
2670 return;
2671 }
2672 auto front_real_kernel = front_real_kernel_pair.first;
2673 std::string kernel_target = GetCNodeTarget(front_real_kernel);
2674 bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
2675 bool unique_target = true;
2676 if (internal_output && common::AnfAlgo::IsNopNode(front_real_kernel)) {
2677 auto pre_node_pair = common::AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
2678 auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
2679 if (pre_node_target != kernel_target) {
2680 unique_target = false;
2681 }
2682 }
2683 if (internal_output) {
2684 auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
2685 for (auto &user : users) {
2686 if (IsNeedAddPartialParameter(user, kernel_target, backend_graph)) {
2687 auto partial_target = AddPartialParametersMap(user);
2688 if (partial_target != kNoTarget && partial_target != kernel_target) {
2689 unique_target = false;
2690 }
2691 continue;
2692 }
2693 if (common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
2694 continue;
2695 }
2696 if (IsUnusedInternlOutput(user)) {
2697 internal_output = false;
2698 break;
2699 }
2700 if (kernel_target != GetCNodeTarget(user)) {
2701 unique_target = false;
2702 }
2703 }
2704 }
2705 if (internal_output) {
2706 MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString()
2707 << ", unique_target: " << unique_target;
2708 backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
2709 }
2710 }
2711
ConstructOutput(const AnfNodePtrList & outputs,const std::shared_ptr<KernelGraph> & graph)2712 CNodePtr KernelGraphMgr::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
2713 MS_EXCEPTION_IF_NULL(graph);
2714 std::vector<AnfNodePtr> output_args;
2715 for (const auto &output : outputs) {
2716 MS_EXCEPTION_IF_NULL(output);
2717 MS_LOG(INFO) << "Output:" << output->DebugString();
2718 }
2719 auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
2720 auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
2721 if (backend_anf != nullptr) {
2722 auto context_ptr = MsContext::GetInstance();
2723 MS_EXCEPTION_IF_NULL(context_ptr);
2724 if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
2725 return backend_anf;
2726 }
2727
2728 MS_EXCEPTION_IF_NULL(out);
2729 auto out_func_graph = out->func_graph();
2730 MS_EXCEPTION_IF_NULL(out_func_graph);
2731 auto out_func_graph_manager = out_func_graph->manager();
2732 if (out_func_graph_manager == nullptr) {
2733 return backend_anf;
2734 }
2735 HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
2736 return backend_anf;
2737 }
2738 MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
2739 };
2740 output_args.push_back(mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())));
2741 (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
2742 [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
2743 auto output_node = graph->NewCNode(output_args);
2744 MS_EXCEPTION_IF_NULL(output_node);
2745 // Create abstract for output maketuple node.
2746 AbstractBasePtrList output_abs_list;
2747 const auto &inputs = output_node->inputs();
2748 (void)std::transform(
2749 inputs.begin() + 1, inputs.end(), std::back_inserter(output_abs_list), [](const AnfNodePtr &input) {
2750 return input->abstract() == nullptr ? std::make_shared<abstract::AbstractNone>() : input->abstract();
2751 });
2752 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(output_abs_list);
2753 MS_EXCEPTION_IF_NULL(abstract_tuple);
2754 output_node->set_abstract(abstract_tuple);
2755 return output_node;
2756 }
2757
NewPynativeKernelGraph()2758 KernelGraphPtr KernelGraphMgr::NewPynativeKernelGraph() {
2759 auto graph = std::make_shared<KernelGraph>();
2760 graph->set_is_from_pynative(true);
2761 MS_EXCEPTION_IF_NULL(graph);
2762 graph->set_graph_id(pynative_graph_sum_);
2763 pynative_graph_sum_++;
2764 return graph;
2765 }
2766
NewKernelGraph()2767 KernelGraphPtr KernelGraphMgr::NewKernelGraph() {
2768 auto graph = std::make_shared<KernelGraph>();
2769 MS_EXCEPTION_IF_NULL(graph);
2770 SetKernelGraphId(graph);
2771 return graph;
2772 }
2773
SetKernelGraphId(const KernelGraphPtr & kernel_graph)2774 void KernelGraphMgr::SetKernelGraphId(const KernelGraphPtr &kernel_graph) {
2775 MS_EXCEPTION_IF_NULL(kernel_graph);
2776 if (graph_sum_ >= kPynativeGraphIdStart) {
2777 MS_LOG(EXCEPTION) << "The graph id in graph mode must be less than " << kPynativeGraphIdStart << ", but it is "
2778 << graph_sum_;
2779 }
2780 kernel_graph->set_graph_id(graph_sum_);
2781 graphs_[graph_sum_++] = kernel_graph;
2782 }
2783
UnifyMindIR(const KernelGraphPtr & graph)2784 void KernelGraphMgr::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIR(graph); }
2785
2786 namespace {
CopyCNodeInfo(const FuncGraphPtr & func_graph,const uint32_t & target_graph_id,const AnfNodePtr & ori_node,const AnfNodePtr & new_node)2787 void CopyCNodeInfo(const FuncGraphPtr &func_graph, const uint32_t &target_graph_id, const AnfNodePtr &ori_node,
2788 const AnfNodePtr &new_node) {
2789 MS_EXCEPTION_IF_NULL(new_node);
2790 MS_EXCEPTION_IF_NULL(ori_node);
2791 auto kernel_info = dynamic_cast<device::KernelInfo *>(new_node->kernel_info());
2792 // deep copy kernel info
2793 if (kernel_info != nullptr && kernel_info->has_build_info()) {
2794 // some check
2795 MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->MutableKernelMod() == nullptr,
2796 "Inline ERROR: " + ori_node->DebugString() + ", kernel mod is not nullptr");
2797 MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->output_address_list().empty(),
2798 "Inline ERROR: " + ori_node->DebugString() + ", output_address_list is not empty");
2799 MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->workspace_address_list().empty(),
2800 "Inline ERROR: " + ori_node->DebugString() + ", workspace_address_list is not empty");
2801
2802 auto new_kernel_info = std::make_shared<device::KernelInfo>();
2803 auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(
2804 AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(new_node));
2805 MS_EXCEPTION_IF_NULL(builder);
2806 MS_EXCEPTION_IF_NULL(new_kernel_info);
2807 new_kernel_info->set_select_kernel_build_info(builder->Build());
2808 new_kernel_info->set_graph_id(target_graph_id);
2809 new_kernel_info->set_feature_map_flag(kernel_info->is_feature_map());
2810 new_kernel_info->set_ref_map(false, kernel_info->out_in_ref_map());
2811 new_node->set_kernel_info(new_kernel_info);
2812 } else {
2813 auto new_kernel_info = std::make_shared<device::KernelInfo>();
2814 new_node->set_kernel_info(new_kernel_info);
2815 }
2816 if (ori_node->isa<CNode>()) {
2817 auto ori_cnode = ori_node->cast<CNodePtr>();
2818 if (common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, ori_cnode) &&
2819 common::AnfAlgo::GetNodeAttr<bool>(ori_node, kAttrIsUBFusionOp)) {
2820 // already done fusion compile
2821 auto ori_full_name = ori_cnode->fullname_with_scope();
2822 common::AnfAlgo::SetNodeAttr(kAttrOriFusionName, MakeValue(ori_full_name), new_node);
2823 }
2824 common::AnfAlgo::SetNodeAttr(kAttrNeedInline, MakeValue(ori_node->fullname_with_scope()), new_node);
2825 common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node);
2826 }
2827 }
2828
UpdateConditionNodePair(const KernelGraphPtr & kernel_graph,const KernelGraphPtr & target_kernel_graph,const mindspore::HashMap<AnfNodePtr,AnfNodePtr> & condition_node_map)2829 void UpdateConditionNodePair(const KernelGraphPtr &kernel_graph, const KernelGraphPtr &target_kernel_graph,
2830 const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &condition_node_map) {
2831 MS_EXCEPTION_IF_NULL(kernel_graph);
2832 const auto &gather_to_switch = kernel_graph->condition_gather_to_switch();
2833 for (const auto &pair : gather_to_switch) {
2834 MS_EXCEPTION_IF_NULL(pair.first);
2835 MS_EXCEPTION_IF_NULL(pair.second);
2836 const auto &gather_iter = condition_node_map.find(pair.first);
2837 const auto &switch_iter = condition_node_map.find(pair.second);
2838 if (gather_iter == condition_node_map.end() || switch_iter == condition_node_map.end()) {
2839 MS_LOG(EXCEPTION) << "Failed to get new gather node:" << pair.first->fullname_with_scope()
2840 << " or switch node:" << pair.second->fullname_with_scope()
2841 << " in graph:" << kernel_graph->ToString();
2842 }
2843 MS_EXCEPTION_IF_NULL(gather_iter->second);
2844 MS_EXCEPTION_IF_NULL(switch_iter->second);
2845 if (target_kernel_graph != nullptr) {
2846 target_kernel_graph->AddConditionGatherSwitchPair(gather_iter->second, switch_iter->second);
2847 MS_LOG(INFO) << "Add condition node pair:" << gather_iter->second->fullname_with_scope()
2848 << " and:" << switch_iter->second->fullname_with_scope()
2849 << " for graph:" << target_kernel_graph->ToString();
2850 const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(pair.second);
2851 if (front_node == nullptr) {
2852 MS_LOG(WARNING) << "Failed to get front node by backend node:" << pair.second->DebugString()
2853 << " in graph:" << kernel_graph->ToString();
2854 continue;
2855 }
2856 target_kernel_graph->FrontBackendMapAdd(front_node, switch_iter->second);
2857 }
2858 }
2859 }
2860 } // namespace
2861
DoInline(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & func_graph_args,const ScopePtr & scope,const uint32_t & target_graph_id,const std::map<session::AnfWithOutIndex,session::AnfWithOutIndex> & ref_map,const KernelGraphPtr & graph,bool is_switch_inline)2862 AnfNodePtr KernelGraphMgr::DoInline(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
2863 const AnfNodePtrList &func_graph_args, const ScopePtr &scope,
2864 const uint32_t &target_graph_id,
2865 const std::map<session::AnfWithOutIndex, session::AnfWithOutIndex> &ref_map,
2866 const KernelGraphPtr &graph, bool is_switch_inline) {
2867 MS_EXCEPTION_IF_NULL(func_graph);
2868 MS_EXCEPTION_IF_NULL(graph);
2869 MS_EXCEPTION_IF_NULL(target_func_graph);
2870 KernelGraphPtr target_kernel_graph = nullptr;
2871 if (target_func_graph->isa<KernelGraph>()) {
2872 target_kernel_graph = target_func_graph->cast<KernelGraphPtr>();
2873 }
2874 Cloner cloner({}, false);
2875 if (scope != nullptr) {
2876 cloner.set_scope(scope);
2877 }
2878 cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
2879 auto node_list = TopoSort(func_graph->output());
2880 mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_node_map;
2881 for (auto &ori_node : node_list) {
2882 MS_EXCEPTION_IF_NULL(ori_node);
2883 if (ori_node->isa<Parameter>()) {
2884 continue;
2885 }
2886 auto new_node = cloner[ori_node];
2887 MS_EXCEPTION_IF_NULL(new_node);
2888 if (new_node->isa<ValueNode>()) {
2889 auto value_node = new_node->cast<ValueNodePtr>();
2890 MS_EXCEPTION_IF_NULL(value_node);
2891 graph->AddValueNodeToGraph(value_node);
2892 }
2893 // Add sub graph kernel for switch inline kernel graph.
2894 if (new_node->isa<CNode>() && target_kernel_graph != nullptr && is_switch_inline) {
2895 MS_LOG(DEBUG) << "Add inline sub graph for kernel:" << new_node->fullname_with_scope()
2896 << " graph:" << func_graph->ToString();
2897 std::string sub_graph_name = func_graph->ToString();
2898 if (func_graph->isa<KernelGraph>()) {
2899 const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
2900 MS_EXCEPTION_IF_NULL(kernel_graph);
2901 const auto &sub_graph_iter = kernel_graph->inline_sub_graph_kernels().find(ori_node);
2902 if (sub_graph_iter != kernel_graph->inline_sub_graph_kernels().end()) {
2903 sub_graph_name = sub_graph_iter->second;
2904 }
2905 }
2906 target_kernel_graph->AddInlineSubgraphKernel(new_node, sub_graph_name);
2907 if (common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionGather) ||
2908 common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionSwitch)) {
2909 condition_node_map[ori_node] = new_node;
2910 }
2911 }
2912 CopyCNodeInfo(func_graph, target_graph_id, ori_node, new_node);
2913 }
2914 // Collect condition gather node and condition switch node.
2915 if (func_graph->isa<KernelGraph>() && is_switch_inline) {
2916 const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
2917 MS_EXCEPTION_IF_NULL(kernel_graph);
2918 UpdateConditionNodePair(kernel_graph, target_kernel_graph, condition_node_map);
2919 }
2920
2921 for (const auto &kv : ref_map) {
2922 auto final_pair = kv.first;
2923 auto origin_pair = kv.second;
2924 final_pair.first = cloner[final_pair.first];
2925 origin_pair.first = cloner[origin_pair.first];
2926 auto new_origin_pair = common::AnfAlgo::VisitKernel(origin_pair.first, origin_pair.second);
2927 graph->AddRefCorrespondPairs(final_pair, new_origin_pair);
2928 }
2929 return cloner[func_graph->output()];
2930 }
2931 } // namespace session
2932 } // namespace mindspore
2933