• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/session/kernel_graph_mgr.h"
17 
18 #include <algorithm>
19 #include <queue>
20 #include <utility>
21 #include <functional>
22 #include <unordered_map>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "include/common/debug/anf_ir_dump.h"
26 #include "runtime/device/kernel_runtime_manager.h"
27 #include "backend/common/optimizer/common_backend_optimization.h"
28 #include "pipeline/pynative/grad/jit/jit_call_graph.h"
29 #include "utils/trace_base.h"
30 #include "ir/func_graph_cloner.h"
31 #ifndef ENABLE_SECURITY
32 #include "include/backend/debug/data_dump/dump_json_parser.h"
33 #include "include/backend/debug/data_dump/e2e_dump.h"
34 #endif
35 #include "include/common/utils/compile_cache_context.h"
36 #include "include/common/utils/config_manager.h"
37 #include "load_mindir/load_model.h"
38 #include "include/common/debug/dump_proto.h"
39 
40 namespace mindspore {
41 namespace session {
42 namespace {
// Upper bound on the number of user nodes RecursiveCheck expands before it gives up.
constexpr size_t kMaxDepth{128};
constexpr size_t kFirstIndex{1};
// RecursiveCheck rejects Depend/Load users that are reached through an input index greater than this.
constexpr int64_t kPairIdx1{1};
// uint32_t max value is 4294967295
// graph id in graph mode start from 0
// graph id in pynative mode start from 4000000000
constexpr uint32_t kPynativeGraphIdStart{4000000000U};
50 
IsGeReturnNode(const AnfNodePtr & node)51 bool IsGeReturnNode(const AnfNodePtr &node) {
52   auto context = MsContext::GetInstance();
53   MS_EXCEPTION_IF_NULL(context);
54   const bool enable_ge = context->backend_policy() == "ge";
55   if (!enable_ge) {
56     return false;
57   }
58   MS_EXCEPTION_IF_NULL(node);
59   auto cnode = node->cast<CNodePtr>();
60   if (cnode == nullptr) {
61     // parameter and value node is a real kernel too
62     return true;
63   }
64   if (cnode->size() == 0) {
65     MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString()
66                                << trace::DumpSourceLines(node);
67   }
68   return IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), {prim::kPrimReturn});
69 }
70 
RecursiveCheck(const FuncGraphManagerPtr & manager,const std::pair<AnfNodePtr,int64_t> & kernel,size_t * idx)71 bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodePtr, int64_t> &kernel, size_t *idx) {
72   auto node = kernel.first;
73   MS_EXCEPTION_IF_NULL(manager);
74   MS_EXCEPTION_IF_NULL(node);
75   if (kernel.second > kPairIdx1 && (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) ||
76                                     common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
77     return false;
78   }
79   if ((AnfUtils::IsRealKernel(node) || IsGeReturnNode(node) || AnfAlgo::IsSummaryNode(node)) &&
80       !common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
81     return true;
82   }
83   (*idx) += 1;
84   // max recursion depth
85   if (*idx <= kMaxDepth) {
86     auto users = manager->node_users()[node];
87     if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
88           return RecursiveCheck(manager, kernel, idx);
89         })) {
90       return true;
91     }
92   }
93   return false;
94 }
95 
IsUsedByRealKernel(const FuncGraphManagerPtr & manager,const AnfNodePtr & node,const uint32_t graph_id)96 bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, const uint32_t graph_id) {
97   MS_EXCEPTION_IF_NULL(manager);
98   MS_EXCEPTION_IF_NULL(node);
99   auto node_users = manager->node_users()[node];
100   // filter nodes not in current graph
101   for (auto iter = node_users.begin(); iter != node_users.end();) {
102     auto func_graph = iter->first->func_graph();
103     MS_EXCEPTION_IF_NULL(func_graph);
104     auto kernel_graph = func_graph->cast<KernelGraphPtr>();
105     if (kernel_graph == nullptr) {
106       MS_LOG(EXCEPTION) << "func graph cast kernel graph failed, related node is: " << iter->first->DebugString();
107     }
108     if (kernel_graph->graph_id() != graph_id) {
109       iter = node_users.erase(iter);
110     } else {
111       iter++;
112     }
113   }
114 
115   size_t idx = 0;
116   if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
117         return RecursiveCheck(manager, kernel, &idx);
118       })) {
119     return true;
120   }
121   return false;
122 }
123 
ExistGraphCaller(const AnfNodePtr & partial_node)124 bool ExistGraphCaller(const AnfNodePtr &partial_node) {
125   MS_EXCEPTION_IF_NULL(partial_node);
126   auto partial_cnode = partial_node->cast<CNodePtr>();
127   MS_EXCEPTION_IF_NULL(partial_cnode);
128   auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
129   // If graph is nullptr, it means that the funcgraph in the partial node is a deadnode, and the processing is skipped.
130   if (partial_graph == nullptr) {
131     return false;
132   }
133   auto graph_nodes = TopoSort(partial_graph->get_return());
134   return std::any_of(graph_nodes.begin(), graph_nodes.end(), IsValueNode<FuncGraph>);
135 }
136 
CheckPath(const std::optional<std::string> & path)137 bool CheckPath(const std::optional<std::string> &path) {
138   if (!path.has_value()) {
139     return false;
140   }
141   std::ifstream f(path.value());
142   bool file_is_good = f.good();
143   f.close();
144   if (!file_is_good) {
145     MS_LOG(WARNING) << "Open the compilation cache file " << path.value() << " failed.";
146     return false;
147   }
148   return true;
149 }
150 
LoadJson(const std::string & filename,nlohmann::json * graph_json)151 bool LoadJson(const std::string &filename, nlohmann::json *graph_json) {
152   std::ifstream json_fs(filename);
153   if (!json_fs.is_open()) {
154     MS_LOG(ERROR) << "Open json file: " << filename << " error, backend graph cache Missed.";
155     return false;
156   }
157   try {
158     json_fs >> *graph_json;
159     json_fs.close();
160   } catch (std::exception &e) {
161     MS_LOG(INFO) << "Parse json file error: " << filename << ", sleep 500ms and retry again.";
162     json_fs.close();
163     std::this_thread::sleep_for(std::chrono::milliseconds(kRetryIntervalMilliSeconds));
164     std::ifstream retry_tmp(filename);
165     if (!retry_tmp.is_open()) {
166       MS_LOG(ERROR) << "Open json file: " << filename << " error, please check cached file.";
167       return false;
168     }
169     retry_tmp >> *graph_json;
170     retry_tmp.close();
171   }
172   return true;
173 }
174 
// Converts `str` to `Type` with formatted stream extraction (leading whitespace is skipped, parsing stops at the
// first character that does not belong to the value).
// Returns a value-initialized `Type` (0 for arithmetic types) when nothing can be extracted.
template <typename Type>
Type StringToNum(const std::string &str) {
  std::istringstream iss(str);
  // Value-initialize the result: for empty or whitespace-only input the extraction gives up before it ever writes to
  // the target, so an uninitialized local would be returned with an indeterminate value.
  Type num{};
  iss >> num;
  return num;
}
182 
LoadKernelInfoRuntimeCache(const nlohmann::json & kernel_info_value,std::shared_ptr<KernelInfoDevice> kernel_info)183 void LoadKernelInfoRuntimeCache(const nlohmann::json &kernel_info_value,
184                                 std::shared_ptr<KernelInfoDevice> kernel_info) {
185   if (!kernel_info_value.contains(kRuntimeCacheValid)) {
186     return;
187   }
188   auto &context = CompileCacheContext::GetInstance();
189   auto &rt = kernel_info->runtime_cache().runtime_cache();
190   rt.set_is_valid(kernel_info_value[kRuntimeCacheValid]);
191   rt.set_device_target(kernel_info_value[kRuntimeCacheDeviceTarget]);
192   rt.set_output_tensor_num(kernel_info_value[kRuntimeCacheOutputTensorNum]);
193   rt.set_real_kernel(kernel_info_value[kRuntimeCacheIsRealKernel]);
194   if (kernel_info_value.contains(kRuntimeCachePrevOutputs)) {
195     const auto &prev_outputs = kernel_info_value[kRuntimeCachePrevOutputs];
196     for (const auto &prev_output : prev_outputs) {
197       const auto &first_index = prev_output.at(0);
198       const auto &name = prev_output.at(kIndexOne);
199       const auto &second_index = prev_output.at(kIndexTwo);
200       auto output_node = context.FindBackNodeByBackName(name);
201       MS_EXCEPTION_IF_NULL(output_node);
202       rt.update_prev_node_output(first_index, std::make_pair(output_node, second_index));
203     }
204   }
205 }
206 
LoadAnfSelectKernelBuildInfo(const nlohmann::json & kernel_info_value,const AnfNodePtr & node)207 void LoadAnfSelectKernelBuildInfo(const nlohmann::json &kernel_info_value, const AnfNodePtr &node) {
208   if (!kernel_info_value.contains(kHasSelectKernelBuildInfo)) {
209     return;
210   }
211   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
212   MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
213 
214   auto kernel_type = kernel_info_value[kKernelType];
215   kernel_build_info_builder->SetKernelType(kernel_type);
216 
217   auto op_type = kernel_info_value[kOpType];
218   kernel_build_info_builder->SetOpType(op_type);
219   auto core_type = kernel_info_value[kCoreType];
220   kernel_build_info_builder->SetCoreType(core_type);
221 
222   if (kernel_info_value.contains(kOriginDataFormat)) {
223     auto origin_data_format = kernel_info_value[kOriginDataFormat];
224     kernel_build_info_builder->SetOriginDataFormat(origin_data_format);
225   }
226 
227   if (kernel_info_value.contains(kAllInputFormat)) {
228     auto all_input_format = kernel_info_value[kAllInputFormat];
229     kernel_build_info_builder->SetInputsFormat(all_input_format);
230   }
231 
232   auto pattern = kernel_info_value[kOpPattern];
233   kernel_build_info_builder->SetOpPattern(pattern);
234   if (kernel_info_value.contains(kAllOutputFormat)) {
235     auto all_output_format = kernel_info_value[kAllOutputFormat];
236     kernel_build_info_builder->SetOutputsFormat(all_output_format);
237   }
238   if (kernel_info_value.contains(kAllInputReshapeType)) {
239     auto all_input_reshape_type = kernel_info_value[kAllInputReshapeType];
240     kernel_build_info_builder->SetInputsReshapeType(all_input_reshape_type);
241   }
242 
243   if (kernel_info_value.contains(kAllOutputReshapeType)) {
244     auto all_output_reshape_type = kernel_info_value[kAllOutputReshapeType];
245     kernel_build_info_builder->SetOutputsReshapeType(all_output_reshape_type);
246   }
247   if (kernel_info_value.contains(kAllInputDeviceType)) {
248     auto all_input_device_type = kernel_info_value[kAllInputDeviceType];
249     kernel_build_info_builder->SetInputsDeviceType(all_input_device_type);
250   }
251   if (kernel_info_value.contains(kAllOutputDeviceType)) {
252     auto all_output_device_type = kernel_info_value[kAllOutputDeviceType];
253     kernel_build_info_builder->SetOutputsDeviceType(all_output_device_type);
254   }
255   if (kernel_info_value.contains(kInputKernelObjectTypes)) {
256     auto input_kernel_object_types = kernel_info_value[kInputKernelObjectTypes];
257     kernel_build_info_builder->SetInputsKernelObjectType(input_kernel_object_types);
258   }
259   if (kernel_info_value.contains(kOutputKernelObjectTypes)) {
260     auto output_kernel_object_types = kernel_info_value[kOutputKernelObjectTypes];
261     kernel_build_info_builder->SetOutputsKernelObjectType(output_kernel_object_types);
262   }
263   if (kernel_info_value.contains(kOutputElementsKernelObjectTypes)) {
264     auto output_elements_kernel_object_types = kernel_info_value[kOutputElementsKernelObjectTypes];
265     kernel_build_info_builder->SetOutputElementsKernelObjectType(output_elements_kernel_object_types);
266   }
267 
268   if (kernel_info_value.contains(kOutputDataDesc)) {
269     auto output_data_desc = kernel_info_value[kOutputDataDesc];
270     kernel_build_info_builder->SetOutputDataDesc(output_data_desc);
271   }
272   auto fusion_type = kernel_info_value[kFusionType];
273   kernel_build_info_builder->SetFusionType(fusion_type);
274   auto processor = kernel_info_value[kProcessor];
275   kernel_build_info_builder->SetProcessor(processor);
276   auto valid = kernel_info_value[kKernelBuildInfoValid];
277   kernel_build_info_builder->SetValid(valid);
278   const auto &kernel_build = kernel_build_info_builder->Build();
279   AnfAlgo::SetSelectKernelBuildInfo(kernel_build, node.get());
280 }
281 
LoadAnfKernelInfoFromJson(const nlohmann::json & kernel_infos_json)282 void LoadAnfKernelInfoFromJson(const nlohmann::json &kernel_infos_json) {
283   auto &context = CompileCacheContext::GetInstance();
284   for (const auto &[name, kernel_info_value] : kernel_infos_json.items()) {
285     auto node = context.FindBackNodeByBackName(name);
286     MS_EXCEPTION_IF_NULL(node);
287     MS_LOG(DEBUG) << "Load node " << node->DebugString() << " kernel info from json.";
288     auto kernel_info = std::make_shared<device::KernelInfo>();
289     MS_EXCEPTION_IF_NULL(kernel_info);
290     if (kernel_info_value.contains(kOutInRef)) {
291       const auto &out_in_ref_json = kernel_info_value[kOutInRef];
292       for (const auto &[out, in] : out_in_ref_json.items()) {
293         kernel_info->AddRefMap(StringToNum<size_t>(out), in);
294       }
295     }
296     kernel_info->set_graph_id(kernel_info_value[kGraphId]);
297     kernel_info->set_feature_map_flag(kernel_info_value[kIsFeatureMap]);
298     node->set_kernel_info(kernel_info);
299     LoadAnfSelectKernelBuildInfo(kernel_info_value, node);
300 
301     if (node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, node->cast<CNodePtr>()) &&
302         common::AnfAlgo::GetNodeAttr<bool>(node->cast<CNodePtr>(), kAttrIsUBFusionOp)) {
303       if (kernel_info_value.contains(kJsonName) && kernel_info_value.contains(kInputSizeList) &&
304           kernel_info_value.contains(kOutputSizeList)) {
305         CachedIOSizeInfo io_size;
306         io_size.json_name = kernel_info_value[kJsonName];
307         const auto &input_size_list = kernel_info_value[kInputSizeList];
308         const auto &output_size_list = kernel_info_value[kOutputSizeList];
309         (void)(io_size.input_size_list.insert(io_size.input_size_list.end(), input_size_list.begin(),
310                                               input_size_list.end()));
311         (void)(io_size.output_size_list.insert(io_size.output_size_list.end(), output_size_list.begin(),
312                                                output_size_list.end()));
313         context.PushFullnameIoSizeInfo(node->fullname_with_scope(), io_size);
314       } else {
315         MS_LOG(EXCEPTION) << "Load node " << node->DebugString() << " kernel_io_size_info failed.";
316       }
317     }
318     LoadKernelInfoRuntimeCache(kernel_info_value, kernel_info);
319   }
320 }
321 
GetAnfUniqueCacheName(const AnfNodePtr & node,bool must_have_unique_name=true)322 std::string GetAnfUniqueCacheName(const AnfNodePtr &node, bool must_have_unique_name = true) {
323   MS_EXCEPTION_IF_NULL(node);
324   const auto &name = node->user_data<std::string>(kUniqueCacheName);
325   if (must_have_unique_name && name == nullptr) {
326     MS_LOG(EXCEPTION) << "The node " << node->DebugString()
327                       << " has not unique name, indicating that it is not exported to mindir.";
328   }
329   return name != nullptr ? *name : "";
330 }
331 
SaveAnfToAnfMap(const HashMap<AnfNodePtr,AnfNodePtr> & save_map)332 nlohmann::json SaveAnfToAnfMap(const HashMap<AnfNodePtr, AnfNodePtr> &save_map) {
333   nlohmann::json ret;
334   for (const auto &i : save_map) {
335     const auto &first_name = GetAnfUniqueCacheName(i.first, false);
336     const auto &second_name = GetAnfUniqueCacheName(i.second, false);
337     // allow some node not to be exported to mindir.
338     if (first_name.empty() || second_name.empty()) {
339       continue;
340     }
341     ret[first_name] = second_name;
342   }
343   return ret;
344 }
345 
SaveAnfToAnfIndexMap(const HashMap<AnfNodePtr,AnfWithOutIndex> & save_map)346 std::vector<nlohmann::json> SaveAnfToAnfIndexMap(const HashMap<AnfNodePtr, AnfWithOutIndex> &save_map) {
347   std::vector<nlohmann::json> ret_json;
348   for (const auto &i : save_map) {
349     nlohmann::json iter_json;
350     const auto &first_name = GetAnfUniqueCacheName(i.first, false);
351     const auto &second_name = GetAnfUniqueCacheName(i.second.first, false);
352     // allow some node not to be exported to mindir.
353     if (first_name.empty() || second_name.empty()) {
354       continue;
355     }
356     iter_json.push_back(first_name);
357     iter_json.push_back(second_name);
358     iter_json.push_back(i.second.second);
359     (void)(ret_json.emplace_back(iter_json));
360   }
361   return ret_json;
362 }
363 
SaveAnfIndexToAnfIndexMap(const std::map<AnfWithOutIndex,AnfWithOutIndex> & save_map)364 std::vector<nlohmann::json> SaveAnfIndexToAnfIndexMap(const std::map<AnfWithOutIndex, AnfWithOutIndex> &save_map) {
365   std::vector<nlohmann::json> ret_json;
366   for (const auto &i : save_map) {
367     nlohmann::json iter_json;
368     const auto &first_name = GetAnfUniqueCacheName(i.first.first, false);
369     const auto &second_name = GetAnfUniqueCacheName(i.second.first, false);
370     // allow some node not to be exported to mindir.
371     if (first_name.empty() || second_name.empty()) {
372       continue;
373     }
374     iter_json.push_back(first_name);
375     iter_json.push_back(i.first.second);
376     iter_json.push_back(second_name);
377     iter_json.push_back(i.second.second);
378     (void)(ret_json.emplace_back(iter_json));
379   }
380   return ret_json;
381 }
382 
SaveValueSet(const HashSet<ValueNodePtr> & save_anfs)383 nlohmann::json SaveValueSet(const HashSet<ValueNodePtr> &save_anfs) {
384   nlohmann::json iter_json;
385   for (const auto &i : save_anfs) {
386     const auto &name = GetAnfUniqueCacheName(i, false);
387     // allow some value node not to be exported to mindir.
388     if (name.empty()) {
389       continue;
390     }
391     (void)(iter_json.emplace_back(name));
392   }
393   return iter_json;
394 }
395 
396 template <typename T>
SaveAnfVec(const std::vector<T> & save_anfs)397 nlohmann::json SaveAnfVec(const std::vector<T> &save_anfs) {
398   nlohmann::json ret_json;
399   for (const auto &i : save_anfs) {
400     const auto &name = GetAnfUniqueCacheName(i);
401     (void)(ret_json.emplace_back(name));
402   }
403   return ret_json;
404 }
405 
SaveGraphVec(const std::vector<std::weak_ptr<KernelGraph>> & save_anfs)406 nlohmann::json SaveGraphVec(const std::vector<std::weak_ptr<KernelGraph>> &save_anfs) {
407   nlohmann::json ret_json;
408   for (const auto &i : save_anfs) {
409     const auto &ptr = i.lock();
410     MS_EXCEPTION_IF_NULL(ptr);
411     (void)(ret_json.emplace_back(ptr->graph_id()));
412   }
413   return ret_json;
414 }
415 
SaveGraphsId(const HashMap<uint32_t,std::weak_ptr<session::KernelGraph>> & to_save)416 nlohmann::json SaveGraphsId(const HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> &to_save) {
417   nlohmann::json ret_json;
418   for (const auto &i : to_save) {
419     nlohmann::json iter_json;
420     const auto &ptr = i.second.lock();
421     MS_EXCEPTION_IF_NULL(ptr);
422     ret_json.push_back(ptr->graph_id());
423   }
424   return ret_json;
425 }
426 
SavePrevOutputs(const std::map<size_t,std::pair<AnfNodeWeakPtr,size_t>> & save_map)427 std::vector<nlohmann::json> SavePrevOutputs(const std::map<size_t, std::pair<AnfNodeWeakPtr, size_t>> &save_map) {
428   std::vector<nlohmann::json> ret_json;
429   for (const auto &i : save_map) {
430     nlohmann::json iter_json;
431     const auto &node = i.second.first.lock();
432     MS_EXCEPTION_IF_NULL(node);
433     const auto &name = GetAnfUniqueCacheName(node, false);
434     if (name.empty()) {
435       continue;
436     }
437     iter_json.push_back(i.first);
438     iter_json.push_back(name);
439     iter_json.push_back(i.second.second);
440     ret_json.push_back(iter_json);
441   }
442   return ret_json;
443 }
444 
SaveKernelInfoRuntimeCache(KernelInfoDevice * kernel_info,nlohmann::json * const single_json)445 void SaveKernelInfoRuntimeCache(KernelInfoDevice *kernel_info, nlohmann::json *const single_json) {
446   MS_EXCEPTION_IF_NULL(kernel_info);
447   MS_EXCEPTION_IF_NULL(single_json);
448   const auto &rt = kernel_info->runtime_cache().runtime_cache();
449   if (!rt.is_valid()) {
450     return;
451   }
452   (*single_json)[kRuntimeCacheValid] = rt.is_valid();
453   (*single_json)[kRuntimeCacheDeviceTarget] = rt.device_target();
454   (*single_json)[kRuntimeCacheOutputTensorNum] = rt.output_tensor_num();
455   (*single_json)[kRuntimeCacheIsRealKernel] = rt.is_real_kernel();
456   const auto &prev_outputs_json = SavePrevOutputs(rt.GetPrevOutputs());
457   if (!prev_outputs_json.empty()) {
458     (*single_json)[kRuntimeCachePrevOutputs] = prev_outputs_json;
459   }
460 }
461 
SaveAnfKernelInfo(const AnfNodePtr & node)462 nlohmann::json SaveAnfKernelInfo(const AnfNodePtr &node) {
463   nlohmann::json single_json;
464   if (AnfUtils::IsRealKernel(node)) {
465     single_json[kOriginDataFormat] = AnfAlgo::GetOriginDataFormat(node);
466     const auto &input_formats = AnfAlgo::GetAllInputFormats(node);
467     if (!input_formats.empty()) {
468       single_json[kAllInputFormat] = input_formats;
469     }
470     const auto &output_formats = AnfAlgo::GetAllOutputFormats(node);
471     if (!output_formats.empty()) {
472       single_json[kAllOutputFormat] = output_formats;
473     }
474     const auto &input_device_types = AnfAlgo::GetAllInputDeviceTypes(node);
475     if (!input_device_types.empty()) {
476       single_json[kAllInputDeviceType] = input_device_types;
477     }
478     const auto &output_device_types = AnfAlgo::GetAllOutputDeviceTypes(node);
479     if (!output_device_types.empty()) {
480       single_json[kAllOutputDeviceType] = output_device_types;
481     }
482   }
483   if (AnfAlgo::HasSelectKernelBuildInfo(node)) {
484     auto kernel_type = AnfAlgo::GetKernelType(node);
485     single_json[kKernelType] = kernel_type;
486     auto op_type = AnfAlgo::GetOpType(node);
487     single_json[kOpType] = op_type;
488     single_json[kCoreType] = AnfAlgo::GetCoreType(node);
489     single_json[kOpPattern] = AnfAlgo::GetOpPattern(node);
490     const auto &input_reshape_types_json = AnfAlgo::GetAllInputReshapeType(node);
491     if (!input_reshape_types_json.empty()) {
492       single_json[kAllInputReshapeType] = input_reshape_types_json;
493     }
494     const auto &output_reshape_types_json = AnfAlgo::GetAllOutputReshapeType(node);
495     if (!output_reshape_types_json.empty()) {
496       single_json[kAllOutputReshapeType] = output_reshape_types_json;
497     }
498     const auto &input_kernel_object_types_json = AnfAlgo::GetInputKernelObjectTypes(node);
499     if (!input_kernel_object_types_json.empty()) {
500       single_json[kInputKernelObjectTypes] = input_kernel_object_types_json;
501     }
502     const auto &output_kernel_object_types_json = AnfAlgo::GetOutputKernelObjectTypes(node);
503     if (!output_kernel_object_types_json.empty()) {
504       single_json[kOutputKernelObjectTypes] = output_kernel_object_types_json;
505     }
506     const auto &output_elements_kernel_object_types_json = AnfAlgo::GetOutputElementsKernelObjectTypes(node);
507     if (!output_elements_kernel_object_types_json.empty()) {
508       single_json[kOutputElementsKernelObjectTypes] = output_elements_kernel_object_types_json;
509     }
510 
511     const auto &output_desc_json = AnfAlgo::GetOutputDataDesc(node);
512     if (!output_desc_json.empty()) {
513       single_json[kOutputDataDesc] = output_desc_json;
514     }
515     single_json[kFusionType] = AnfAlgo::GetFusionType(node);
516     single_json[kProcessor] = AnfAlgo::GetProcessor(node);
517     single_json[kKernelBuildInfoValid] = AnfAlgo::GetValid(node);
518     single_json[kHasSelectKernelBuildInfo] = true;
519   }
520   const auto &kernel_info = node->kernel_info();
521   MS_EXCEPTION_IF_NULL(kernel_info);
522   const auto &device_kernel_info = dynamic_cast<device::KernelInfo *>(kernel_info);
523   MS_EXCEPTION_IF_NULL(device_kernel_info);
524   nlohmann::json out_in_ref_json;
525   const auto &out_in_ref = device_kernel_info->out_in_ref_map();
526   (void)(std::for_each(out_in_ref.begin(), out_in_ref.end(),
527                        [&out_in_ref_json](const auto &iter) { out_in_ref_json[iter.first] = iter.second; }));
528   if (!out_in_ref_json.empty()) {
529     single_json[kOutInRef] = out_in_ref_json;
530   }
531   const auto &graph_id = device_kernel_info->graph_id();
532   single_json[kGraphId] = graph_id;
533   const auto &is_feature_map = device_kernel_info->is_feature_map();
534   single_json[kIsFeatureMap] = is_feature_map;
535 
536   if (node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, node->cast<CNodePtr>()) &&
537       common::AnfAlgo::GetNodeAttr<bool>(node->cast<CNodePtr>(), kAttrIsUBFusionOp)) {
538     auto &context = CompileCacheContext::GetInstance();
539     const auto &io_size = context.GetIOSizeInfo(node->fullname_with_scope());
540     single_json[kJsonName] = io_size.json_name;
541     const auto input_size_list_json = io_size.input_size_list;
542     if (!input_size_list_json.empty()) {
543       single_json[kInputSizeList] = input_size_list_json;
544     }
545     const auto output_size_list_json = io_size.output_size_list;
546     if (!output_size_list_json.empty()) {
547       single_json[kOutputSizeList] = output_size_list_json;
548     }
549   }
550   SaveKernelInfoRuntimeCache(kernel_info, &single_json);
551   return single_json;
552 }
553 
SaveBackendParamToFrontendParamIndex(const KernelGraphPtr & kernel_graph,const FuncGraph * front_graph)554 nlohmann::json SaveBackendParamToFrontendParamIndex(const KernelGraphPtr &kernel_graph, const FuncGraph *front_graph) {
555   nlohmann::json ret;
556   const auto &params = kernel_graph->parameters();
557   auto &context = CompileCacheContext::GetInstance();
558   const auto &front_params = front_graph->parameters();
559   for (const auto &param : params) {
560     if (!context.IsBackendParamGenFromFrontendParam(param)) {
561       continue;
562     }
563     const auto &name = param->user_data<std::string>(kUniqueCacheName);
564     MS_EXCEPTION_IF_NULL(name);
565     const auto &front_param = kernel_graph->GetFrontAnfByBackendAnf(param);
566     MS_EXCEPTION_IF_NULL(front_param);
567     auto iter = std::find(front_params.begin(), front_params.end(), front_param);
568     if (iter == front_params.end()) {
569       MS_LOG(EXCEPTION) << "Backend param " << param->DebugString() << " correspond frontend param "
570                         << front_param->DebugString() << " can not find in frontend graph params.";
571     }
572     ret[*name] = std::distance(front_params.begin(), iter);
573   }
574   return ret;
575 }
576 
SaveNodesKernelInfoAndParamsName(const KernelGraphPtr & kg,const std::vector<AnfNodePtr> & isolated_nodes,nlohmann::json * const kg_json)577 void SaveNodesKernelInfoAndParamsName(const KernelGraphPtr &kg, const std::vector<AnfNodePtr> &isolated_nodes,
578                                       nlohmann::json *const kg_json) {
579   std::vector<AnfNodePtr> nodes = TopoSort(kg->get_return(), SuccIncoming, AlwaysInclude);
580   nlohmann::json kernels_info_json;
581   (void)(nodes.insert(nodes.end(), isolated_nodes.begin(), isolated_nodes.end()));
582   const auto &params = kg->parameters();
583   std::vector<AnfNodePtr> isolated_params;
584   (void)(std::set_difference(params.begin(), params.end(), nodes.begin(), nodes.end(),
585                              std::back_inserter(isolated_params)));
586   (void)(nodes.insert(nodes.end(), isolated_params.begin(), isolated_params.end()));
587   nlohmann::json param_unique_name_to_name;
588   for (const auto &node : nodes) {
589     MS_EXCEPTION_IF_NULL(node);
590     if (node->kernel_info() == nullptr && !node->isa<Parameter>()) {
591       continue;
592     }
593     const auto &name = GetAnfUniqueCacheName(node);
594     if (node->isa<Parameter>()) {
595       auto param = node->cast<ParameterPtr>();
596       param_unique_name_to_name[name] = param->name();
597     }
598     if (node->kernel_info() == nullptr) {
599       MS_LOG(WARNING) << "The node " << node->DebugString() << " has not kernel_info.";
600       continue;
601     }
602     const auto &kernel_info_json = SaveAnfKernelInfo(node);
603     if (!kernel_info_json.empty()) {
604       kernels_info_json[name] = kernel_info_json;
605     }
606   }
607   (*kg_json)[kParameterUniqueNameToName] = param_unique_name_to_name;
608   (*kg_json)[kNodesKernelInfo] = kernels_info_json;
609 }
610 
SaveSummaryNodes(const std::map<std::string,std::pair<AnfNodePtr,int>> & save_map)611 std::vector<nlohmann::json> SaveSummaryNodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &save_map) {
612   std::vector<nlohmann::json> ret_json;
613   for (const auto &i : save_map) {
614     nlohmann::json iter_json;
615     const auto &first = i.first;
616     const auto &node = i.second.first;
617     const auto &name = GetAnfUniqueCacheName(node);
618     const auto &index = i.second.second;
619     iter_json.push_back(first);
620     iter_json.push_back(name);
621     iter_json.push_back(index);
622     ret_json.push_back(iter_json);
623   }
624   return ret_json;
625 }
626 
GenKernelGraphJson(const KernelGraphPtr & kg,const std::vector<AnfNodePtr> & isolated_nodes)627 nlohmann::json GenKernelGraphJson(const KernelGraphPtr &kg, const std::vector<AnfNodePtr> &isolated_nodes) {
628   nlohmann::json kg_json;
629   SaveNodesKernelInfoAndParamsName(kg, isolated_nodes, &kg_json);
630   kg_json[kGraphId] = kg->graph_id();
631   kg_json[kRunMode] = kg->RunMode();
632   kg_json[kIsLoopCountSink] = kg->is_loop_count_sink();
633   kg_json[kIsDynamicShape] = kg->is_dynamic_shape();
634   kg_json[kDeviceTarget] = kg->device_target();
635   kg_json[kRootGraphId] = kg->root_graph_id();
636   kg_json[kExecutable] = kg->executable();
637   kg_json[kHasRecursiveCall] = kg->recursive_call();
638   kg_json[kHasSubgraphMultiCall] = kg->subgraph_multi_call();
639   kg_json[kNeedInline] = kg->need_inline();
640   kg_json[kIsNeedGil] = kg->is_need_gil();
641   kg_json[kIsFromSingleOp] = kg->is_from_single_op();
642   kg_json[kLabelNum] = kg->label_num();
643 #ifndef ENABLE_SECURITY
644   kg_json[kSummaryNodeExist] = kg->summary_node_exist();
645 #endif
646   const auto &back_front_anf_json = SaveAnfToAnfMap(kg->backend_front_anf_map());
647   if (!back_front_anf_json.empty()) {
648     kg_json[kBackendFrontAnf] = back_front_anf_json;
649   }
650   const auto &internal_params_to_front_node_json = SaveAnfToAnfIndexMap(kg->InternalParameterToFrontNodeMap());
651   if (!internal_params_to_front_node_json.empty()) {
652     kg_json[kInternalParameterToFrontNode] = internal_params_to_front_node_json;
653   }
654   const auto &ref_in_out_map_json = SaveAnfIndexToAnfIndexMap(kg->GetRefMap());
655   if (!ref_in_out_map_json.empty()) {
656     kg_json[kRefInOutMap] = ref_in_out_map_json;
657   }
658   const auto &graph_value_nodes = SaveValueSet(kg->graph_value_nodes());
659   if (!graph_value_nodes.empty()) {
660     kg_json[kGraphValueNodes] = graph_value_nodes;
661   }
662   const auto &exec_order_json = SaveAnfVec(kg->execution_order());
663   if (!exec_order_json.empty()) {
664     kg_json[kExecutionOrder] = exec_order_json;
665   }
666   const auto &inputs_json = SaveAnfVec(kg->inputs());
667   if (!inputs_json.empty()) {
668     kg_json[kInputs] = inputs_json;
669   }
670   const auto &parameters_json = SaveAnfVec(kg->parameters());
671   if (!parameters_json.empty()) {
672     kg_json[kParameters] = parameters_json;
673   }
674   const auto &child_graph_result_json = SaveAnfVec(kg->child_graph_result());
675   if (!child_graph_result_json.empty()) {
676     kg_json[kChildGraphResult] = child_graph_result_json;
677   }
678   const auto &child_graph_order_json = SaveGraphVec(kg->child_graph_order());
679   if (!child_graph_order_json.empty()) {
680     kg_json[kChildGraphOrder] = child_graph_order_json;
681   }
682   const auto &start = kg->get_start_label();
683   if (start) {
684     kg_json[kStartLabel] = GetAnfUniqueCacheName(start);
685   }
686   const auto &end = kg->get_end_goto();
687   if (end) {
688     kg_json[kEndGoto] = GetAnfUniqueCacheName(end);
689   }
690   const auto &valid_inputs = kg->valid_inputs();
691   if (!valid_inputs.empty()) {
692     kg_json[kValidInputs] = valid_inputs;
693   }
694   const auto &pre_graphs_json = SaveGraphsId(kg->get_pre_graphs());
695   if (!pre_graphs_json.empty()) {
696     kg_json[kPreGraphs] = pre_graphs_json;
697   }
698   const auto &post_graphs_json = SaveGraphsId(kg->GetPostGraphs());
699   if (!post_graphs_json.empty()) {
700     kg_json[kPostGraphs] = post_graphs_json;
701   }
702   const auto &index_set = kg->CommSubGraphIds();
703   if (!index_set.empty()) {
704     kg_json[kCommSubGraphIds] = index_set;
705   }
706 #ifndef ENABLE_SECURITY
707   const auto &summary_nodes_json = SaveSummaryNodes(kg->summary_nodes());
708   if (!summary_nodes_json.empty()) {
709     kg_json[kSummaryNodes] = summary_nodes_json;
710   }
711 #endif
712   auto &context = CompileCacheContext::GetInstance();
713   auto front_graph = context.GetFrontendGraphByBackendGraph(kg);
714   if (front_graph) {
715     kg_json[kCorrespondFrontendGraph] = front_graph->ToString();
716   }
717   kg_json[kBackendParamToFrontendParamIndex] = SaveBackendParamToFrontendParamIndex(kg, front_graph);
718   return kg_json;
719 }
720 
DumpKernelGraphJson(const KernelGraphPtr & root_graph,const std::set<KernelGraphPtr> & child_graphs,const std::map<KernelGraphPtr,std::vector<AnfNodePtr>> & isolated_nodes_map,const std::string & path)721 bool DumpKernelGraphJson(const KernelGraphPtr &root_graph, const std::set<KernelGraphPtr> &child_graphs,
722                          const std::map<KernelGraphPtr, std::vector<AnfNodePtr>> &isolated_nodes_map,
723                          const std::string &path) {
724   nlohmann::json kg_json;
725   kg_json[root_graph->ToString()] = GenKernelGraphJson(root_graph, isolated_nodes_map.find(root_graph)->second);
726   for (const auto &graph : child_graphs) {
727     kg_json[graph->ToString()] = GenKernelGraphJson(graph, isolated_nodes_map.find(graph)->second);
728   }
729   return Common::SaveStringToFile(path, kg_json.dump());
730 }
731 
GetAllChildGraph(const KernelGraphPtr & kg,std::set<KernelGraphPtr> * visit,std::set<KernelGraphPtr> * graphs)732 void GetAllChildGraph(const KernelGraphPtr &kg, std::set<KernelGraphPtr> *visit, std::set<KernelGraphPtr> *graphs) {
733   if (kg == nullptr || kg->IsLeafGraph()) {
734     return;
735   }
736   MS_EXCEPTION_IF_NULL(visit);
737   MS_EXCEPTION_IF_NULL(graphs);
738   if (visit->find(kg) != visit->end()) {
739     return;
740   }
741   const auto &order = kg->child_graph_order();
742   for (auto iter : order) {
743     auto graph = iter.lock();
744     MS_EXCEPTION_IF_NULL(graph);
745     (void)(graphs->insert(graph));
746   }
747   (void)(visit->insert(kg));
748 
749   for (auto &i : order) {
750     GetAllChildGraph(i.lock(), visit, graphs);
751   }
752 }
753 
GetIsolatedNodes(const KernelGraphPtr & kg,std::vector<AnfNodePtr> * isolated_nodes)754 void GetIsolatedNodes(const KernelGraphPtr &kg, std::vector<AnfNodePtr> *isolated_nodes) {
755   MS_EXCEPTION_IF_NULL(kg);
756   const auto &orders = kg->execution_order();
757   std::vector<AnfNodePtr> possible_isolated(orders.begin(), orders.end());
758   const auto &start = kg->get_start_label();
759   if (start && std::find(possible_isolated.begin(), possible_isolated.end(), start) == possible_isolated.end()) {
760     possible_isolated.push_back(start);
761   }
762   const auto &end = kg->get_end_goto();
763   if (end && std::find(possible_isolated.begin(), possible_isolated.end(), end) == possible_isolated.end()) {
764     possible_isolated.push_back(end);
765   }
766   auto topo_nodes = TopoSort(kg->get_return(), SuccIncoming, AlwaysInclude);
767   (void)(std::set_difference(possible_isolated.begin(), possible_isolated.end(), topo_nodes.begin(), topo_nodes.end(),
768                              std::back_inserter(*isolated_nodes)));
769 }
770 
HandleParamExistCorrespondFrontendParam(const KernelGraphPtr & graph)771 void HandleParamExistCorrespondFrontendParam(const KernelGraphPtr &graph) {
772   MS_EXCEPTION_IF_NULL(graph);
773   auto &context = CompileCacheContext::GetInstance();
774   const auto &front_graph = context.GetFrontendGraphByBackendGraph(graph);
775   if (!front_graph) {
776     return;
777   }
778   const auto &params = graph->parameters();
779   const auto &front_params = front_graph->parameters();
780   for (const auto &param : params) {
781     auto front_param = graph->GetFrontAnfByBackendAnf(param);
782     if (!front_param) {
783       continue;
784     }
785     auto iter = std::find(front_params.begin(), front_params.end(), front_param);
786     if (iter != front_params.end()) {
787       context.InsertBackendParamGenFromFrontendParam(param);
788     }
789   }
790 }
791 
NeedConvertValueNodeToParameter(const AnfNodePtr & node)792 bool NeedConvertValueNodeToParameter(const AnfNodePtr &node) {
793   MS_EXCEPTION_IF_NULL(node);
794   auto ms_context = MsContext::GetInstance();
795   MS_EXCEPTION_IF_NULL(ms_context);
796   if (ms_context->backend_policy() != "ge" || ms_context->IsKByKExecutorMode()) {
797     return false;
798   }
799   if (!node->isa<ValueNode>()) {
800     return false;
801   }
802   auto value_node = node->cast<ValueNodePtr>();
803   MS_EXCEPTION_IF_NULL(value_node);
804   auto value = value_node->value();
805   MS_EXCEPTION_IF_NULL(value);
806   if (value->isa<tensor::Tensor>()) {
807     auto tensor = value->cast<tensor::TensorPtr>();
808     MS_EXCEPTION_IF_NULL(tensor);
809     if (tensor->is_forward_output()) {
810       return true;
811     }
812   }
813   return false;
814 }
815 
ConvertValueNodeToParameter(const KernelGraphPtr & graph,const AnfNodePtr & node,std::vector<ParameterPtr> * added_parameters)816 void ConvertValueNodeToParameter(const KernelGraphPtr &graph, const AnfNodePtr &node,
817                                  std::vector<ParameterPtr> *added_parameters) {
818   MS_EXCEPTION_IF_NULL(graph);
819   MS_EXCEPTION_IF_NULL(node);
820   MS_EXCEPTION_IF_NULL(added_parameters);
821   auto graph_inputs = graph->MutableInputs();
822   MS_EXCEPTION_IF_NULL(graph_inputs);
823   auto new_parameter = graph->NewParameter(node->abstract());
824   MS_EXCEPTION_IF_NULL(new_parameter);
825   new_parameter->IncreaseUsedGraphCount();
826   graph_inputs->push_back(new_parameter);
827 
828   MS_EXCEPTION_IF_NULL(node->cast<ValueNodePtr>());
829   new_parameter->set_user_data(kForwardOutput, node->cast<ValueNodePtr>()->value());
830   graph->FrontBackendMapAdd(node, new_parameter);
831   (void)added_parameters->emplace_back(new_parameter);
832   MS_LOG(DEBUG) << "Replace ValueNode " << node->DebugString() << " with Parameter " << new_parameter->DebugString();
833 }
834 }  // namespace
835 
GetParamDefaultValue(const AnfNodePtr & node)836 ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
837   if (node == nullptr) {
838     return nullptr;
839   }
840   auto parameter = node->cast<ParameterPtr>();
841   if (parameter == nullptr || !parameter->has_default()) {
842     return nullptr;
843   }
844   return parameter->param_info();
845 }
846 
847 #ifndef ENABLE_SECURITY
ExistSummaryNode(const KernelGraph * graph)848 bool ExistSummaryNode(const KernelGraph *graph) {
849   MS_EXCEPTION_IF_NULL(graph);
850   for (auto &n : TopoSort(graph->get_return())) {
851     if (AnfAlgo::IsSummaryNode(n)) {
852       return true;
853     }
854   }
855   return false;
856 }
857 #endif
858 
859 GraphId KernelGraphMgr::graph_sum_ = 0;
860 GraphId KernelGraphMgr::pynative_graph_sum_ = kPynativeGraphIdStart;
861 HashMap<std::string, std::weak_ptr<AnfNode>> KernelGraphMgr::name_to_params_ =
862   HashMap<std::string, std::weak_ptr<AnfNode>>();
863 
CreateNewValueNode(const AnfNodePtr & anf,KernelGraph * graph) const864 ValueNodePtr KernelGraphMgr::CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) const {
865   MS_EXCEPTION_IF_NULL(anf);
866   MS_EXCEPTION_IF_NULL(graph);
867   auto value_node = anf->cast<ValueNodePtr>();
868   MS_EXCEPTION_IF_NULL(value_node);
869   auto value = value_node->value();
870   MS_EXCEPTION_IF_NULL(value);
871   // Copy data from device if the tensor is an output of Op or Graph.
872   if (value->isa<tensor::Tensor>()) {
873     auto tensor = value->cast<TensorPtr>();
874     MS_EXCEPTION_IF_NULL(tensor);
875     if (!tensor->is_forward_output() && !tensor->is_parameter()) {
876       tensor->data_sync();
877       MS_LOG(INFO) << "Data sync for Tensor " << tensor->ToString();
878     }
879   }
880   auto new_value_node = value_node;
881   if (!graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
882     new_value_node = graph->NewValueNode(value_node);
883     graph->FrontBackendMapAdd(anf, new_value_node);
884   }
885   graph->AddValueNodeToGraph(new_value_node);
886   return new_value_node;
887 }
888 
GetGraphIdByNode(const AnfNodePtr & front_anf) const889 GraphId KernelGraphMgr::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
890   for (const auto &graph_item : graphs_) {
891     auto graph = graph_item.second;
892     MS_EXCEPTION_IF_NULL(graph);
893     // if front_anf is a parameter,the backend parameter may have two
894     if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
895       return graph_item.first;
896     }
897   }
898   MS_EXCEPTION_IF_NULL(front_anf);
899   MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
900   return kInvalidGraphId;
901 }
902 
GetGraph(mindspore::GraphId graph_id) const903 KernelGraphPtr KernelGraphMgr::GetGraph(mindspore::GraphId graph_id) const {
904   auto it = graphs_.find(graph_id);
905   if (it == graphs_.end()) {
906     MS_LOG(INFO) << "Can't find graph " << graph_id;
907     return nullptr;
908   }
909   return it->second;
910 }
911 
ClearGraph()912 void KernelGraphMgr::ClearGraph() {
913   auto graph_iter = graphs_.begin();
914   while (graph_iter != graphs_.end()) {
915     graph_iter->second.reset();
916     graph_iter = graphs_.erase(graph_iter);
917   }
918   graph_sum_ = 0;
919   pynative_graph_sum_ = kPynativeGraphIdStart;
920 }
921 
InitInternalOutputParameter(const AnfNodePtr & out_node,const AnfNodePtr & parameter) const922 void KernelGraphMgr::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) const {
923   MS_EXCEPTION_IF_NULL(out_node);
924   MS_EXCEPTION_IF_NULL(parameter);
925   MS_LOG(DEBUG) << "parameter:" << parameter->DebugString()
926                 << " abstract:" << (parameter->abstract() != nullptr ? parameter->abstract()->ToString() : "null");
927   auto graph_id = GetGraphIdByNode(out_node);
928   if (graph_id == kInvalidGraphId) {
929     return;
930   }
931   auto node_graph = GetGraph(graph_id);
932   if (node_graph == nullptr) {
933     return;
934   }
935   MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
936   auto ref_node_with_index = node_graph->GetInternalOutputByFrontNode(out_node);
937   auto ref_node = ref_node_with_index.first;
938   if (ref_node == nullptr) {
939     MS_LOG(INFO) << "No corresponding internal output for output node";
940     return;
941   }
942   size_t output_idx = ref_node_with_index.second;
943   if (common::AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
944     output_idx = common::AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
945   }
946   auto real_kernel = common::AnfAlgo::VisitKernel(ref_node, output_idx);
947   auto ref_real_node = real_kernel.first;
948   MS_EXCEPTION_IF_NULL(ref_real_node);
949   auto ref_real_node_index = real_kernel.second;
950   if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
951     auto kernel_info = ref_real_node->kernel_info();
952     if (kernel_info == nullptr || !kernel_info->has_build_info()) {
953       MS_LOG(INFO) << "No kernel info";
954       return;
955     }
956     if (!common::AnfAlgo::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
957       MS_LOG(INFO) << "No kernel address";
958       return;
959     }
960     if (!AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index, true)) {
961       return;
962     }
963 
964     // Update the kernel build info.
965     auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
966     auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
967     if (type == TypeId::kTypeUnknown) {
968       return;
969     }
970     kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
971     builder.SetOutputsDeviceType({type});
972     builder.SetOutputsFormat({format});
973     AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), parameter.get());
974 
975     abstract::AbstractBasePtr abstract;
976     auto shape = parameter->Shape();
977     MS_EXCEPTION_IF_NULL(shape);
978     if (shape->isa<abstract::NoShape>()) {
979       abstract = std::make_shared<abstract::AbstractScalar>(TypeIdToType(type));
980     } else if (shape->isa<abstract::DynamicSequenceShape>()) {
981       return;
982     } else {
983       abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape->cast<abstract::BaseShapePtr>());
984     }
985     if (!parameter->abstract()->isa<abstract::AbstractAny>()) {
986       parameter->set_abstract(abstract);
987     }
988   }
989 }
990 
// Create a new parameter of `graph` from the abstract of `node`, expanded through TransTupleToMakeTuple.
// Every flattened output of the new parameter is appended to the graph inputs (marked valid) and, when MindRT
// is enabled, cached as an internal parameter of the front node. Each flattened parameter is then initialized
// from the matching output of the previous graph. Returns the (possibly make_tuple) parameter node.
AnfNodePtr KernelGraphMgr::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
  auto parameters = common::AnfAlgo::GetAllOutput(new_parameter);

  // If a cnode is a call, its input0 is a cnode too, so it doesn't have primitive
  std::vector<AnfNodePtr> pre_graph_out;
  if (AnfUtils::IsRealKernel(node)) {
    pre_graph_out = {node};
  } else {
    pre_graph_out = common::AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
  }

  for (size_t flat_idx = 0; flat_idx < parameters.size(); ++flat_idx) {
    const auto &parameter = parameters[flat_idx];
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
      // In control flow, if the input of the cnode is a call node, it will be processed as a make_tuple input,
      // which needs to be linked when processing the internal node.
      graph->CacheInternalParameterToFrontNode(parameter, {node, flat_idx});
    }
    auto valid_inputs = graph->MutableValidInputs();
    MS_EXCEPTION_IF_NULL(valid_inputs);
    auto graph_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(graph_inputs);
    valid_inputs->push_back(true);
    graph_inputs->push_back(parameter);
  }

  // Walk the previous graph outputs and hand each of their elements the next flattened parameter.
  size_t param_index = 0;
  for (const auto &out_node : pre_graph_out) {
    const size_t range_end = param_index + AnfAlgo::GetOutputElementNum(out_node);
    for (; param_index < range_end; ++param_index) {
      if (param_index >= parameters.size()) {
        MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << "out of range.Node:" << node->DebugString()
                          << ",out_node:" << out_node->DebugString();
      }
      InitInternalOutputParameter(out_node, parameters[param_index]);
    }
  }
  return new_parameter;
}
1031 
CreateNewParameterFromParameter(const AnfNodePtr & anf,KernelGraph * graph)1032 ParameterPtr KernelGraphMgr::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
1033   MS_EXCEPTION_IF_NULL(anf);
1034   if (!anf->isa<Parameter>()) {
1035     MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
1036   }
1037   MS_EXCEPTION_IF_NULL(graph);
1038   auto param_value = GetParamDefaultValue(anf);
1039   auto valid_inputs = graph->MutableValidInputs();
1040   MS_EXCEPTION_IF_NULL(valid_inputs);
1041   auto graph_inputs = graph->MutableInputs();
1042   MS_EXCEPTION_IF_NULL(graph_inputs);
1043   ParameterPtr new_parameter = nullptr;
1044   auto func_graph = anf->func_graph();
1045   MS_EXCEPTION_IF_NULL(func_graph);
1046   bool is_pynative_bprop_kernel_graph = graph->has_flag(kFlagIsPyNativeBpropKernelGraph);
1047   if (func_graph->manager() != nullptr && func_graph->exist_multi_target() &&
1048       graph->device_target() == device::DeviceType::kCPU) {
1049     auto iter = default_param_map_.find(anf);
1050     if (iter != default_param_map_.end()) {
1051       new_parameter = iter->second;
1052     }
1053     if (new_parameter != nullptr) {
1054       graph_inputs->push_back(new_parameter);
1055       MS_LOG(DEBUG) << "create new parameter for parameter:" << anf->DebugString() << " for graph:" << graph->ToString()
1056                     << " backend node:" << new_parameter->DebugString();
1057       return new_parameter;
1058     }
1059     TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1060     new_parameter = anf->cast<ParameterPtr>();
1061     if (!is_pynative_bprop_kernel_graph) {
1062       new_parameter = graph->NewParameter(new_parameter);
1063     }
1064     graph_inputs->push_back(new_parameter);
1065     valid_inputs->push_back(true);
1066     default_param_map_[anf] = new_parameter;
1067     return new_parameter;
1068   }
1069   // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
1070   if (!is_pynative_bprop_kernel_graph) {
1071     auto context = MsContext::GetInstance();
1072     if (!context->IsKByKExecutorMode() && param_value != nullptr) {
1073       new_parameter = param_value->parameter();
1074     }
1075     if (new_parameter == nullptr) {
1076       TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1077       new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1078 
1079       auto input_node_iter = partial_parameters_map_.find(anf);
1080       if (input_node_iter != partial_parameters_map_.end()) {
1081         InitInternalOutputParameter(input_node_iter->second, new_parameter);
1082       }
1083 
1084       if (param_value != nullptr) {
1085         param_value->set_parameter(new_parameter);
1086       }
1087     }
1088     new_parameter->IncreaseUsedGraphCount();
1089   } else {
1090     new_parameter = anf->cast<ParameterPtr>();
1091   }
1092   (void)graph_inputs->emplace_back(new_parameter);
1093   (void)valid_inputs->emplace_back(true);
1094   return new_parameter;
1095 }
1096 
// Create the parameter of `graph` that stands for the cnode `anf`; the work is done by
// CreateParameterFromTuple after validating the arguments and logging the request.
AnfNodePtr KernelGraphMgr::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  auto parameter_from_tuple = CreateParameterFromTuple(anf, graph);
  return parameter_from_tuple;
}
1103 
GetCNodeInfo(const CNodePtr & cnode,std::vector<AnfNodePtr> * cnode_inputs) const1104 void KernelGraphMgr::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const {
1105   MS_EXCEPTION_IF_NULL(cnode);
1106   MS_EXCEPTION_IF_NULL(cnode_inputs);
1107   auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
1108   if (prim != nullptr) {
1109     // push attr to inputs[0] of new cnode
1110     cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
1111   } else {
1112     auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
1113     MS_EXCEPTION_IF_NULL(fg);
1114     auto new_fg = BasicClone(fg);
1115     cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
1116   }
1117 }
1118 
GetChildGraph(KernelGraph * graph,const AnfNodePtr & child_func_graph)1119 AnfNodePtr KernelGraphMgr::GetChildGraph(KernelGraph *graph, const AnfNodePtr &child_func_graph) {
1120   MS_EXCEPTION_IF_NULL(child_func_graph);
1121   std::vector<KernelGraphPtr> all_graphs;
1122   FuncGraphPtr child_graph = common::AnfAlgo::GetValueNodeFuncGraph(child_func_graph);
1123   MS_EXCEPTION_IF_NULL(child_graph);
1124   if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
1125     (void)ConstructKernelGraph(child_graph, &all_graphs, graph->device_target());
1126   }
1127   auto new_value_node = graph->GetBackendAnfByFrontAnf(child_func_graph);
1128   if (new_value_node != nullptr) {
1129     return new_value_node;
1130   }
1131   new_value_node = CreateValueNodeKernelGraph(child_func_graph, graph);
1132   MS_EXCEPTION_IF_NULL(new_value_node);
1133   return new_value_node;
1134 }
1135 
1136 namespace {
AddValueNode(const AnfNodePtr & backend_node,KernelGraph * graph)1137 void AddValueNode(const AnfNodePtr &backend_node, KernelGraph *graph) {
1138   if (backend_node->isa<ValueNode>() && !IsValueNode<FuncGraph>(backend_node)) {
1139     graph->AddValueNodeToGraph(backend_node->cast<ValueNodePtr>());
1140   }
1141 }
1142 }  // namespace
1143 
GetNewCNodeInputs(const CNodePtr & cnode,KernelGraph * graph,std::vector<AnfNodePtr> * cnode_inputs,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * other_graph_cnode)1144 void KernelGraphMgr::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
1145                                        mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
1146   MS_EXCEPTION_IF_NULL(cnode);
1147   MS_EXCEPTION_IF_NULL(graph);
1148   MS_EXCEPTION_IF_NULL(other_graph_cnode);
1149   MS_EXCEPTION_IF_NULL(cnode_inputs);
1150   auto origin_inputs = cnode->inputs();
1151   const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
1152   auto context = MsContext::GetInstance();
1153   MS_EXCEPTION_IF_NULL(context);
1154   const bool enable_ge = context->backend_policy() == "ge";
1155   AnfNodePtr child_func_graph = nullptr;
1156   std::vector<AnfNodePtr> params;
1157   // if has multiple depends,only select first depend as parameter
1158   for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
1159     auto anf = origin_inputs[input_idx];
1160     MS_EXCEPTION_IF_NULL(anf);
1161     // anf has been created before
1162     if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
1163       const auto &backend_node = graph->GetBackendAnfByFrontAnf(anf);
1164       (void)params.emplace_back(backend_node);
1165       AddValueNode(backend_node, graph);
1166       continue;
1167     } else if ((is_depend && input_idx > kRealInputIndexInDepend && !enable_ge)) {
1168       (void)params.emplace_back(graph->NewValueNode(std::make_shared<Tensor>(SizeToInt(input_idx))));
1169       continue;
1170     } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
1171       (void)params.emplace_back((*other_graph_cnode)[anf]);
1172       continue;
1173     } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
1174       // if input is a value node,
1175       auto new_value_node = CreateNewValueNode(anf, graph);
1176       if (new_value_node != nullptr) {
1177         (void)params.emplace_back(new_value_node);
1178       }
1179       continue;
1180     } else if (anf->isa<Parameter>()) {
1181       auto new_parameter = CreateNewParameterFromParameter(anf, graph);
1182       MS_EXCEPTION_IF_NULL(new_parameter);
1183       MS_LOG(DEBUG) << "Create new parameter:" << new_parameter->DebugString()
1184                     << " by front parameter:" << anf->DebugString();
1185       (void)params.emplace_back(new_parameter);
1186       graph->FrontBackendMapAdd(anf, new_parameter);
1187       continue;
1188     } else if (IsValueNode<FuncGraph>(anf) && cnode->HasPrimalAttr(kAttrNotCut)) {
1189       MS_EXCEPTION_IF_CHECK_FAIL(input_idx == 1, "Graph input index is not 1, anf: " + anf->DebugString() +
1190                                                    ", index: " + std::to_string(input_idx));
1191       child_func_graph = anf;
1192       continue;
1193     } else {
1194       // the input node is a cnode from other graph
1195       auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
1196       if (parameter_from_cnode == nullptr) {
1197         parameter_from_cnode = NewValueNode(MakeValue(SizeToLong(input_idx)));
1198       }
1199       MS_EXCEPTION_IF_NULL(parameter_from_cnode);
1200       MS_LOG(DEBUG) << "graph:" << graph->ToString() << " front node:" << anf->DebugString()
1201                     << " abstract:" << (anf->abstract() != nullptr ? anf->abstract()->ToString() : "null")
1202                     << " parameter:" << parameter_from_cnode->DebugString() << " abstract:"
1203                     << (parameter_from_cnode->abstract() != nullptr ? parameter_from_cnode->abstract()->ToString()
1204                                                                     : "null");
1205       if (parameter_from_cnode->isa<Parameter>() && IsPrimitiveCNode(anf, prim::kPrimLoad)) {
1206         auto para = parameter_from_cnode->cast<ParameterPtr>();
1207         auto load_cnode = anf->cast<CNodePtr>();
1208         para->set_name(load_cnode->fullname_with_scope());
1209       }
1210       (void)params.emplace_back(parameter_from_cnode);
1211       (*other_graph_cnode)[anf] = parameter_from_cnode;
1212     }
1213   }
1214 
1215   if (child_func_graph != nullptr) {
1216     (void)cnode_inputs->emplace_back(GetChildGraph(graph, child_func_graph));
1217   }
1218   (void)std::copy(params.begin(), params.end(), std::back_inserter(*cnode_inputs));
1219 }
1220 
// Create the backend cnode of `graph` for the front `cnode`.
// A cnode that needs control flow sink (its input 0 is a Switch cnode and it carries kAttrNotCut) or backend
// inline (it carries kAttrNeedInline) is built by the two-argument overload and then flattened; for backend
// inline the kernel graph of a called func graph is created first. Any other cnode gets its inputs through
// GetCNodeInfo and GetNewCNodeInputs.
CNodePtr KernelGraphMgr::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
                                        mindspore::HashMap<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  // control flow sink to GE
  const bool need_control_flow_sink =
    IsPrimitiveCNode(cnode->input(kAnfPrimitiveIndex), prim::kPrimSwitch) && cnode->HasPrimalAttr(kAttrNotCut);
  // backend inline
  const bool need_backend_inline = cnode->HasPrimalAttr(kAttrNeedInline);
  if (need_backend_inline) {
    auto fn = cnode->input(kAnfPrimitiveIndex);
    MS_EXCEPTION_IF_NULL(fn);
    if (IsValueNode<FuncGraph>(fn)) {
      // Need to create a new kernel graph
      (void)GetChildGraph(graph, fn);
    }
  }

  if (!need_control_flow_sink && !need_backend_inline) {
    // Regular node: the primitive (or func graph) comes first, followed by the converted inputs.
    std::vector<AnfNodePtr> cnode_inputs;
    GetCNodeInfo(cnode, &cnode_inputs);
    GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
    TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
    return graph->NewCNodeWithInfos(cnode_inputs, cnode);
  }

  auto new_cnode = CreateNewCNode(cnode, graph);
  MS_EXCEPTION_IF_NULL(new_cnode);
  FlattenTuple(new_cnode);
  if (need_backend_inline) {
    new_cnode->AddPrimalAttr(kAttrNeedInline, MakeValue(true));
  }
  MS_LOG(DEBUG) << "Create new call node:" << new_cnode->DebugString() << " by front node:" << cnode->DebugString();
  return new_cnode;
}
1258 
// Create the backend cnode of `graph` for the front `cnode`, handling graph calls and switch/partial calls.
// Input 0 of the new node depends on input 0 of `cnode`: a func graph value node, a cnode (call of
// partial/switch/switch_layer) or a plain primitive. When the created node is a Call whose first data input is
// a Switch or SwitchLayer node, that input is returned instead of the call.
// Returns nullptr when the switch or partial node could not be created.
CNodePtr KernelGraphMgr::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);

  std::vector<AnfNodePtr> cnode_inputs;
  if (IsValueNode<FuncGraph>(attr_input)) {
    // cnode is a graph or a call
    cnode_inputs = CreateValueNode(cnode, graph);
  } else if (!attr_input->isa<CNode>()) {
    // get primitive of old node
    auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  } else {
    // cnode is a call (partial/switch/switch_layer)
    // 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph
    // 2. the call in frontend is mapped to the partial/switch/switch_layer in backend and hasn't been created yet
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
    if (cnode_inputs.empty()) {
      MS_LOG(ERROR) << "Create switch or partial failed, cnode:" << cnode->DebugString();
      return nullptr;
    }
  }

  // handle inputs of cnode except primitive
  CreateCNodeInputs(cnode, graph, &cnode_inputs);
  TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
  MS_EXCEPTION_IF_NULL(new_cnode);
  if (new_cnode->size() <= 1) {
    return new_cnode;
  }

  // if the cnode is call switch, remove call
  auto first_input = new_cnode->input(kFirstDataInputIndex);
  MS_EXCEPTION_IF_NULL(first_input);
  if (common::AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall)) {
    if (common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
      new_cnode = first_input->cast<CNodePtr>();
    } else if (common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
      // The switch layer node takes over the abstract of the front call.
      auto abstract = cnode->abstract();
      new_cnode = first_input->cast<CNodePtr>();
      new_cnode->set_abstract(abstract);
    }
  }
  return new_cnode;
}
1306 
// Turn one branch input of a switch into a Partial cnode of `graph`.
//  - A Partial front node: its existing backend node is returned directly.
//  - A func graph value node: a Partial over its backend node is created.
//  - Anything else: a new kernel graph that only returns a parameter created from `cnode` is built, and the
//    Partial is created over that graph with the backend node of `node_input` as argument.
CNodePtr KernelGraphMgr::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node_input);
  MS_EXCEPTION_IF_NULL(graph);
  // switch input generalizes partial
  std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  if (common::AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
    auto backend_node = graph->GetBackendAnfByFrontAnf(node_input);
    MS_EXCEPTION_IF_NULL(backend_node);
    return backend_node->cast<CNodePtr>();
  }

  if (IsValueNode<FuncGraph>(node_input)) {
    (void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
    return graph->NewCNode(partial_inputs);
  }

  // Wrap the input in a kernel graph whose only job is to return a parameter.
  KernelGraphPtr kernel_graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
  MS_EXCEPTION_IF_NULL(parameter);
  MS_EXCEPTION_IF_NULL(cnode);
  parameter->set_abstract(cnode->abstract());
  auto return_prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
  auto return_node = kernel_graph->NewCNode({return_prim, parameter});
  MS_EXCEPTION_IF_NULL(return_node);
  return_node->set_abstract(cnode->abstract());
  kernel_graph->set_return(return_node);
  (void)partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
  (void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  return graph->NewCNode(partial_inputs);
}
1336 
CacheKernelGraph(const KernelGraphPtr & kg)1337 void KernelGraphMgr::CacheKernelGraph(const KernelGraphPtr &kg) {
1338   MS_EXCEPTION_IF_NULL(kg);
1339   auto &context = CompileCacheContext::GetInstance();
1340   auto fg = context.FrontGraph();
1341   if (!fg) {
1342     MS_LOG(EXCEPTION) << "The frontend graph to be cached is null";
1343   }
1344   if (!kg) {
1345     MS_LOG(EXCEPTION) << "The backend graph to be cached is null";
1346   }
1347   MS_LOG(INFO) << "Begin to cache kernel graph " << kg->ToString();
1348   std::set<KernelGraphPtr> visit;
1349   std::set<KernelGraphPtr> child_graphs;
1350   GetAllChildGraph(kg, &visit, &child_graphs);
1351 #ifdef ENABLE_DUMP_IR
1352   auto ms_context = MsContext::GetInstance();
1353   MS_EXCEPTION_IF_NULL(ms_context);
1354   if (ms_context->CanDump(kIntroductory)) {
1355     DumpIR("compile_cache_" + kg->ToString() + ".ir", kg);
1356     for (auto &graph : child_graphs) {
1357       DumpIR("compile_cache_" + graph->ToString() + ".ir", graph);
1358     }
1359   }
1360 #endif
1361   std::vector<AnfNodePtr> temp_nodes;
1362   std::map<KernelGraphPtr, std::vector<AnfNodePtr>> isolated_nodes_map;
1363   HandleParamExistCorrespondFrontendParam(kg);
1364   GetIsolatedNodes(kg, &temp_nodes);
1365   isolated_nodes_map[kg] = temp_nodes;
1366   auto cache_path = context.GetBackendGraphCachePath(fg);
1367   const std::string &mindir_path = cache_path + kMindIrSuffix;
1368   for (const auto &graph : child_graphs) {
1369     temp_nodes.clear();
1370     HandleParamExistCorrespondFrontendParam(graph);
1371     GetIsolatedNodes(graph, &temp_nodes);
1372     isolated_nodes_map[graph] = temp_nodes;
1373   }
1374   std::vector<AnfNodePtr> isolated_nodes;
1375   for (const auto &iter : isolated_nodes_map) {
1376     const auto &nodes = iter.second;
1377     (void)(isolated_nodes.insert(isolated_nodes.end(), nodes.begin(), nodes.end()));
1378   }
1379   std::vector<FuncGraphPtr> child_graphs_for_dump(child_graphs.begin(), child_graphs.end());
1380   if (!DumpBinaryProto(kg, child_graphs_for_dump, isolated_nodes, mindir_path)) {
1381     MS_LOG(ERROR) << "Failed to cache kernel graph to mindir: " << fg->ToString();
1382     return;
1383   }
1384   (void)(std::for_each(front_backend_graph_map_.begin(), front_backend_graph_map_.end(),
1385                        [&context](const auto &fb) { context.AddBackendGraphToFrontendGraph(fb.second, fb.first); }));
1386   const std::string &json_path = cache_path + kJsonSuffix;
1387   if (!DumpKernelGraphJson(kg, child_graphs, isolated_nodes_map, json_path)) {
1388     MS_LOG(ERROR) << "Failed to cache kernel graph to json.";
1389     return;
1390   }
1391   context.Clear();
1392   MS_LOG(INFO) << "Cache kernel graph " << kg->ToString() << " success.";
1393 }
1394 
CreateCallSwitchInputs(const CNodePtr & cnode,KernelGraph * graph) const1395 std::vector<AnfNodePtr> KernelGraphMgr::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) const {
1396   MS_EXCEPTION_IF_NULL(cnode);
1397   MS_EXCEPTION_IF_NULL(graph);
1398   std::vector<AnfNodePtr> cnode_inputs = {
1399     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1400   auto attr_input = cnode->input(kAnfPrimitiveIndex);
1401   MS_EXCEPTION_IF_NULL(attr_input);
1402   auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
1403   MS_EXCEPTION_IF_NULL(cnode_input);
1404   auto switch_cnode = cnode_input->cast<CNodePtr>();
1405   MS_EXCEPTION_IF_NULL(switch_cnode);
1406   if (cnode->size() <= 1) {
1407     cnode_inputs = switch_cnode->inputs();
1408     return cnode_inputs;
1409   }
1410   std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
1411                                            switch_cnode->input(kFirstDataInputIndex)};
1412   for (size_t index = kSwitchTrueBranchIndex; index < switch_cnode->size(); index++) {
1413     auto node = switch_cnode->input(index);
1414     MS_EXCEPTION_IF_NULL(node);
1415     // there is real input in call, should put it to true and false branch in switch
1416     if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
1417       auto partial_node = node->cast<CNodePtr>();
1418       MS_EXCEPTION_IF_NULL(partial_node);
1419       std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
1420       // Put all call args at the end of partial inputs.
1421       for (size_t i = kFirstDataInputIndex; i < cnode->size(); ++i) {
1422         (void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(i)));
1423       }
1424       auto new_partial = graph->NewCNode(partial_inputs);
1425       (void)switch_inputs.emplace_back(new_partial);
1426     }
1427   }
1428   if (switch_inputs.size() < kSwitchInputSize) {
1429     MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
1430   }
1431   auto switch_node = graph->NewCNode(switch_inputs);
1432   (void)cnode_inputs.emplace_back(switch_node);
1433   return cnode_inputs;
1434 }
1435 
ProcessNodeRetFunc(const CNodePtr & cnode,KernelGraph * graph,const std::vector<AnfNodePtr> & real_inputs)1436 void KernelGraphMgr::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
1437                                         const std::vector<AnfNodePtr> &real_inputs) {
1438   MS_EXCEPTION_IF_NULL(cnode);
1439   // func1 =switch(branch1, branch2)
1440   // func2 = func1(param1)
1441   // out = func2(param2)
1442   // process the last cnode(func2), not func1 which abstract is AbstractFunction
1443   if (cnode->abstract()->isa<abstract::AbstractFunction>()) {
1444     return;
1445   }
1446   MS_EXCEPTION_IF_NULL(graph);
1447   auto ret = graph->get_return();
1448   MS_EXCEPTION_IF_NULL(ret);
1449   auto return_input = ret->input(kFirstDataInputIndex);
1450   // return node is a function
1451   std::vector<AnfNodePtr> call_inputs = {
1452     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1453   if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
1454     auto return_input_cnode = return_input->cast<CNodePtr>();
1455     MS_EXCEPTION_IF_NULL(return_input_cnode);
1456     auto partial_inputs = return_input_cnode->inputs();
1457     (void)call_inputs.insert(call_inputs.cend(), partial_inputs.cbegin() + kFirstDataInputIndex, partial_inputs.cend());
1458   } else if (IsValueNode<KernelGraph>(return_input)) {  // return node is kernel graph
1459     (void)(call_inputs.emplace_back(return_input));
1460   } else {  // return node is value node
1461     KernelGraphPtr kernel_graph = NewKernelGraph();
1462     MS_EXCEPTION_IF_NULL(kernel_graph);
1463     auto valid_inputs = kernel_graph->MutableValidInputs();
1464     MS_EXCEPTION_IF_NULL(valid_inputs);
1465     auto graph_inputs = kernel_graph->MutableInputs();
1466     MS_EXCEPTION_IF_NULL(graph_inputs);
1467     std::vector<AnfNodePtr> cnode_inputs = {return_input};
1468     for (auto &real_input : real_inputs) {
1469       auto new_parameter = kernel_graph->NewParameter(real_input->abstract());
1470       valid_inputs->push_back(true);
1471       graph_inputs->push_back(new_parameter);
1472       cnode_inputs.push_back(new_parameter);
1473     }
1474     auto new_cnode = kernel_graph->NewCNode(cnode_inputs);
1475     new_cnode->set_abstract(cnode->abstract());
1476     std::vector<AnfNodePtr> return_inputs = {
1477       kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode};
1478     auto return_node = kernel_graph->NewCNode(return_inputs);
1479     return_node->set_abstract(cnode->abstract());
1480     kernel_graph->set_return(return_node);
1481     call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph));
1482   }
1483 
1484   // new call node inputs
1485   for (auto &input_node : real_inputs) {
1486     auto parameter_for_input = CreateNewParameterFromCNode(input_node, graph);
1487     (void)(call_inputs.emplace_back(parameter_for_input));
1488   }
1489 
1490   auto call_node = graph->NewCNode(call_inputs);
1491   MS_EXCEPTION_IF_NULL(call_node);
1492   call_node->set_abstract(cnode->abstract());
1493   // update return input
1494   ret->set_input(kFirstDataInputIndex, call_node);
1495 }
1496 
CreateCallSwitchLayerInputs(const CNodePtr & cnode,KernelGraph * graph)1497 std::vector<AnfNodePtr> KernelGraphMgr::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
1498   MS_EXCEPTION_IF_NULL(cnode);
1499   MS_EXCEPTION_IF_NULL(graph);
1500   std::vector<AnfNodePtr> cnode_inputs = {
1501     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1502   auto attr_input = cnode->input(kAnfPrimitiveIndex);
1503   MS_EXCEPTION_IF_NULL(attr_input);
1504   auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
1505   MS_EXCEPTION_IF_NULL(cnode_input);
1506   auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
1507   MS_EXCEPTION_IF_NULL(switch_layer_cnode);
1508   std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
1509                                                  switch_layer_cnode->input(kFirstDataInputIndex)};
1510   auto make_tuple_node = switch_layer_cnode->input(kSwitchLayerBranchesIndex);
1511   MS_EXCEPTION_IF_NULL(make_tuple_node);
1512   auto node = make_tuple_node->cast<CNodePtr>();
1513   MS_EXCEPTION_IF_NULL(node);
1514   auto make_tuple_inputs = node->inputs();
1515   // there are real inputs in call, should put it to make_tuple in switch_layer
1516   std::vector<AnfNodePtr> real_inputs;
1517   for (size_t idx = kFirstDataInputIndex; idx < cnode->size(); ++idx) {
1518     (void)(real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx))));
1519   }
1520   std::vector<AnfNodePtr> new_make_tuple_inputs = {
1521     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
1522   for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
1523     auto partial_idx = make_tuple_inputs[idx];
1524     MS_EXCEPTION_IF_NULL(cnode->abstract());
1525     std::vector<AnfNodePtr> new_partial_inputs;
1526     KernelGraphPtr partial_kernel_graph;
1527     // switch_layer node input is partial cnode
1528     if (common::AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
1529       auto partial_node = partial_idx->cast<CNodePtr>();
1530       MS_EXCEPTION_IF_NULL(partial_node);
1531       auto partial_input = partial_node->input(kFirstDataInputIndex);
1532       partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
1533       new_partial_inputs = partial_node->inputs();
1534     } else if (IsValueNode<KernelGraph>(partial_idx)) {  // switch_layer node input is kernel graph value node
1535       (void)(new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))));
1536       (void)(new_partial_inputs.emplace_back(partial_idx));
1537       partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx);
1538     }
1539     // when branch in swich_layer return function
1540     MS_EXCEPTION_IF_NULL(partial_kernel_graph);
1541     auto ret = partial_kernel_graph->get_return();
1542     MS_EXCEPTION_IF_NULL(ret);
1543     auto return_input = ret->input(kFirstDataInputIndex);
1544     if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
1545       ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
1546     }
1547     // partial node add input args
1548     (void)new_partial_inputs.insert(new_partial_inputs.cend(), real_inputs.cbegin(), real_inputs.cend());
1549     // create new partial node
1550     auto new_partial = graph->NewCNode(new_partial_inputs);
1551     (void)(new_make_tuple_inputs.emplace_back(new_partial));
1552   }
1553   auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
1554   auto abstract = make_tuple_node->abstract();
1555   if (abstract == nullptr) {
1556     abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
1557   }
1558   new_make_tuple->set_abstract(abstract);
1559   (void)(switch_layer_inputs.emplace_back(new_make_tuple));
1560   auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
1561   (void)(cnode_inputs.emplace_back(new_switch_layer));
1562   return cnode_inputs;
1563 }
1564 
CreateSwitchOrPartialNode(const CNodePtr & cnode,KernelGraph * graph)1565 std::vector<AnfNodePtr> KernelGraphMgr::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
1566   MS_EXCEPTION_IF_NULL(cnode);
1567   MS_EXCEPTION_IF_NULL(graph);
1568   // create primitive of cnode:call(partial or switch or switch_layer)
1569   std::vector<AnfNodePtr> cnode_inputs = {
1570     graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1571   auto attr_input = cnode->input(kAnfPrimitiveIndex);
1572   MS_EXCEPTION_IF_NULL(attr_input);
1573   auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
1574   if (cnode_input == nullptr) {
1575     MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created.";
1576     return {};
1577   }
1578   // if the node is partial, insert the inputs of partial to the call
1579   if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
1580     auto partial_node = attr_input->cast<CNodePtr>();
1581     MS_EXCEPTION_IF_NULL(partial_node);
1582     auto partial_inputs = partial_node->inputs();
1583     (void)std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
1584                          std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
1585                            MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
1586                            return graph->GetBackendAnfByFrontAnf(node);
1587                          });
1588     return cnode_inputs;
1589   } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
1590     return CreateCallSwitchInputs(cnode, graph);
1591   } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
1592     return CreateCallSwitchLayerInputs(cnode, graph);
1593   } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
1594     // only support tuple get item from a call subgraph output
1595     auto tuple_get_node = cnode_input->cast<CNodePtr>();
1596     MS_EXCEPTION_IF_NULL(tuple_get_node);
1597     auto get_from_node = tuple_get_node->input(kFirstIndex);
1598     MS_EXCEPTION_IF_NULL(get_from_node);
1599     if (common::AnfAlgo::CheckPrimitiveType(get_from_node, prim::kPrimCall)) {
1600       auto call_node = get_from_node->cast<CNodePtr>();
1601       MS_EXCEPTION_IF_NULL(call_node);
1602       auto call_graph = call_node->input(kFirstIndex);
1603       auto sub_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(call_graph);
1604       MS_EXCEPTION_IF_NULL(sub_kernel_graph);
1605       if (kernel_graph_partial_map_.find(sub_kernel_graph.get()) == kernel_graph_partial_map_.end()) {
1606         MS_LOG(EXCEPTION) << "Kernel Graph: " << sub_kernel_graph->ToString()
1607                           << " has not a return value is a Partial Func.";
1608       }
1609       auto tuple_get_idx = common::AnfAlgo::GetTupleGetItemOutIndex(tuple_get_node);
1610       auto info = kernel_graph_partial_map_[sub_kernel_graph.get()];
1611       call_node->set_abstract(info.abstract);
1612       (void)cnode_inputs.emplace_back(info.sub_graph);
1613       auto context = MsContext::GetInstance();
1614       MS_EXCEPTION_IF_NULL(context);
1615       if (context->CellReuseLevel() == CellReuseLevel::kLazyInline) {
1616         // call_graph and info.sub_graph need inline when cell reuse.
1617         sub_kernel_graph->set_need_inline(true);
1618         auto partial_sub_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(info.sub_graph);
1619         MS_EXCEPTION_IF_NULL(partial_sub_graph);
1620         partial_sub_graph->set_need_inline(true);
1621         MS_LOG(INFO) << "Inline graph " << sub_kernel_graph->graph_id() << " and graph "
1622                      << partial_sub_graph->graph_id();
1623       }
1624       MS_LOG(INFO) << "Use cell reuse: " << sub_kernel_graph->graph_id();
1625       if (info.param_begin != tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)) {
1626         MS_LOG(EXCEPTION) << "Call param is not a graph, the TupleGetItem index: " << tuple_get_idx
1627                           << ", the partial graph index: " << info.param_begin
1628                           << ", need idx: " << tuple_get_idx + std::max(static_cast<int>(info.multi_tuple) - 1, 0)
1629                           << ", call graph: " << call_graph->fullname_with_scope();
1630       }
1631       for (size_t i = info.param_begin; i < info.param_end; i++) {
1632         auto idx = NewValueNode(SizeToLong(i));
1633         MS_EXCEPTION_IF_NULL(idx);
1634         auto imm = std::make_shared<Int64Imm>(i);
1635         idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
1636         auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), call_node, idx});
1637         std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(call_node, i)};
1638         auto shapes = {common::AnfAlgo::GetOutputInferShape(call_node, i)};
1639         common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
1640         (void)cnode_inputs.emplace_back(getitem);
1641       }
1642       return cnode_inputs;
1643     }
1644   }
1645   MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
1646                 << "must be partial or switch or switch_layer.";
1647   return {};
1648 }
1649 
CreateValueNode(const CNodePtr & cnode,KernelGraph * graph)1650 std::vector<AnfNodePtr> KernelGraphMgr::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
1651   MS_EXCEPTION_IF_NULL(cnode);
1652   MS_EXCEPTION_IF_NULL(graph);
1653   std::vector<AnfNodePtr> cnode_inputs;
1654   auto attr_input = cnode->input(kAnfPrimitiveIndex);
1655   MS_EXCEPTION_IF_NULL(attr_input);
1656   if (common::AnfAlgo::IsGraphKernel(cnode)) {
1657     auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
1658     MS_EXCEPTION_IF_NULL(fg);
1659     auto new_fg = BasicClone(fg);
1660     cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
1661   } else {
1662     // create primitive of cnode:call
1663     cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
1664     // create a ValueNode<KernelGraph> as input of cnode:call
1665     if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
1666       (void)(cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)));
1667     } else {
1668       auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
1669       if (new_value_node != nullptr) {
1670         (void)(cnode_inputs.emplace_back(new_value_node));
1671       }
1672     }
1673   }
1674   return cnode_inputs;
1675 }
1676 
CreateCNodeInputs(const CNodePtr & cnode,KernelGraph * graph,std::vector<AnfNodePtr> * cnode_inputs)1677 void KernelGraphMgr::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
1678                                        std::vector<AnfNodePtr> *cnode_inputs) {
1679   MS_EXCEPTION_IF_NULL(cnode);
1680   MS_EXCEPTION_IF_NULL(graph);
1681   MS_EXCEPTION_IF_NULL(cnode_inputs);
1682   if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
1683     (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
1684     for (size_t index = kSwitchTrueBranchIndex; index < cnode->size(); index++) {
1685       auto node_input = cnode->input(index);
1686       auto switch_input = CreateSwitchInput(cnode, node_input, graph);
1687       (void)cnode_inputs->emplace_back(switch_input);
1688     }
1689   } else {
1690     for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->size(); input_idx++) {
1691       auto anf = cnode->input(input_idx);
1692       MS_EXCEPTION_IF_NULL(anf);
1693       // anf has been created before
1694       if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
1695         (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
1696         continue;
1697       } else if (anf->isa<Parameter>()) {
1698         auto new_parameter = CreateNewParameterFromParameter(anf, graph);
1699         MS_EXCEPTION_IF_NULL(new_parameter);
1700         (void)cnode_inputs->emplace_back(new_parameter);
1701         graph->FrontBackendMapAdd(anf, new_parameter);
1702         continue;
1703       } else if (anf->isa<ValueNode>()) {
1704         auto new_value_node = CreateNewValueNode(anf, graph);
1705         MS_EXCEPTION_IF_NULL(new_value_node);
1706         (void)cnode_inputs->emplace_back(new_value_node);
1707         continue;
1708       }
1709       MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
1710     }
1711   }
1712 }
1713 
CreateValueNodeKernelGraph(const AnfNodePtr & anf,KernelGraph * graph)1714 ValueNodePtr KernelGraphMgr::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
1715   MS_EXCEPTION_IF_NULL(anf);
1716   MS_EXCEPTION_IF_NULL(graph);
1717   auto value_node = anf->cast<ValueNodePtr>();
1718   MS_EXCEPTION_IF_NULL(value_node);
1719   auto sub_func_graph = common::AnfAlgo::GetValueNodeFuncGraph(anf);
1720   MS_EXCEPTION_IF_NULL(sub_func_graph);
1721   if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) {
1722     MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
1723   }
1724   auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph.get()];
1725 
1726   ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
1727   MS_EXCEPTION_IF_NULL(new_value_node);
1728   new_value_node->set_abstract(value_node->abstract());
1729   // create new kernel_info of new value_node
1730   auto kernel_info = std::make_shared<device::KernelInfo>();
1731   MS_EXCEPTION_IF_NULL(kernel_info);
1732   new_value_node->set_kernel_info(kernel_info);
1733   // create kernel_build_info for new value node
1734   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
1735   MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
1736   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
1737   AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
1738 
1739   graph->FrontBackendMapAdd(anf, new_value_node);
1740 
1741   return new_value_node;
1742 }
1743 
CreateNewParameter(const AnfNodePtr & anf,KernelGraph * graph) const1744 ParameterPtr KernelGraphMgr::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) const {
1745   MS_EXCEPTION_IF_NULL(anf);
1746   MS_EXCEPTION_IF_NULL(graph);
1747   if (!anf->isa<Parameter>()) {
1748     MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
1749   }
1750 
1751   auto param_value = GetParamDefaultValue(anf);
1752   ParameterPtr new_parameter = nullptr;
1753   // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
1754   if (param_value != nullptr) {
1755     new_parameter = param_value->parameter();
1756     if (new_parameter == nullptr) {
1757       TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1758       new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1759       param_value->set_parameter(new_parameter);
1760     }
1761   } else {
1762     TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1763     new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1764   }
1765 
1766   new_parameter->IncreaseUsedGraphCount();
1767 
1768   return new_parameter;
1769 }
1770 
FlattenTuple(const CNodePtr & node)1771 void KernelGraphMgr::FlattenTuple(const CNodePtr &node) {
1772   MS_EXCEPTION_IF_NULL(node);
1773   if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
1774     auto call_graph = node->input(kFirstIndex);
1775     auto sub_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(call_graph);
1776     MS_EXCEPTION_IF_NULL(sub_kernel_graph);
1777     auto iter = kernel_graph_partial_map_.find(sub_kernel_graph.get());
1778     if (iter != kernel_graph_partial_map_.end() && iter->second.multi_tuple != 0) {
1779       (void)need_flatten_.insert(node);
1780     }
1781   } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
1782     auto input = node->input(kFirstIndex);
1783     auto get_idx = common::AnfAlgo::GetTupleGetItemOutIndex(node);
1784     if (need_flatten_.find(input) != need_flatten_.end() && get_idx == 0) {
1785       need_flatten_tuple_map_[node] = input;
1786     }
1787   }
1788   for (size_t i = 0; i < common::AnfAlgo::GetInputNum(node); i++) {
1789     auto input = common::AnfAlgo::GetInputNode(node, i);
1790     auto iter = need_flatten_tuple_map_.find(input);
1791     if (iter != need_flatten_tuple_map_.end()) {
1792       node->set_input(i + 1, iter->second);
1793     }
1794   }
1795 }
1796 
CreateCNodeOfKernelGraph(const AnfNodePtr & node,KernelGraph * graph)1797 bool KernelGraphMgr::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
1798   MS_EXCEPTION_IF_NULL(node);
1799   MS_EXCEPTION_IF_NULL(graph);
1800   auto cnode = node->cast<CNodePtr>();
1801   MS_EXCEPTION_IF_NULL(cnode);
1802   // create a new cnode object
1803   auto new_cnode = CreateNewCNode(cnode, graph);
1804   if (new_cnode == nullptr) {
1805     return false;
1806   }
1807   new_cnode->set_abstract(cnode->abstract());
1808   std::string fullname = cnode->fullname_with_scope();
1809   auto prim_input = cnode->input(kAnfPrimitiveIndex);
1810   // cnode is a call (partial/switch/switch_layer), full scope name is "1_2".
1811   // it is hard to analysis bug when it used as ge node name.
1812   if (!prim_input->isa<CNode>()) {
1813     new_cnode->set_fullname_with_scope(fullname);
1814   }
1815   new_cnode->set_scope(cnode->scope());
1816   if (!graph->is_dynamic_shape() && common::AnfAlgo::IsDynamicShape(new_cnode)) {
1817     graph->SetGraphDynamicAttr(true);
1818   }
1819   graph->FrontBackendMapAdd(node, new_cnode);
1820   SetReturnNode(new_cnode, graph);
1821   FlattenTuple(new_cnode);
1822   return true;
1823 }
1824 
AddParameterToGraphInputs(const std::vector<AnfNodePtr> & parameters,KernelGraph * graph) const1825 void KernelGraphMgr::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) const {
1826   MS_EXCEPTION_IF_NULL(graph);
1827   auto graph_inputs = graph->MutableInputs();
1828   MS_EXCEPTION_IF_NULL(graph_inputs);
1829   graph_inputs->clear();
1830   for (auto &parameter : parameters) {
1831     MS_EXCEPTION_IF_NULL(parameter);
1832     auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
1833     if (backend_parameter == nullptr) {
1834       // for example "def f(x,y,z) {return x + y}", parameter z in unused
1835       auto new_parameter = CreateNewParameter(parameter, graph);
1836       graph_inputs->push_back(new_parameter);
1837       graph->FrontBackendMapAdd(parameter, new_parameter);
1838       MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
1839       continue;
1840     }
1841     graph_inputs->push_back(backend_parameter);
1842   }
1843 }
1844 
1845 // 1. Convert the node to make_tuple if the node is a ValueNode<ValueTuple> and it's the input of 'return' node.
1846 // 2. Set the return of graph if node is "Return" node.
1847 // 3. If the return of graph has a Partial Func, should inline it in return value.
SetReturnNode(const AnfNodePtr & node,KernelGraph * graph)1848 void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
1849   MS_EXCEPTION_IF_NULL(graph);
1850   MS_EXCEPTION_IF_NULL(node);
1851   if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
1852     return;
1853   }
1854   constexpr auto kReturnInputIdx = 1;
1855   auto return_node = node->cast<CNodePtr>();
1856   MS_EXCEPTION_IF_NULL(return_node);
1857   graph->set_return(return_node);
1858   auto graph_output = return_node->input(kReturnInputIdx);
1859   MS_EXCEPTION_IF_NULL(graph_output);
1860 
1861   // If return's input is value node, then the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
1862   // match this pattern because that pass begin with output node but return node. So we add transform value tuple
1863   // to make_tuple here.
1864   if (common::AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
1865     return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
1866   }
1867 
1868   // inline partial to call graph
1869   auto return_tuple = return_node->input(kReturnInputIdx);
1870   MS_EXCEPTION_IF_NULL(return_tuple);
1871   if (return_tuple->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(return_tuple, prim::kPrimMakeTuple)) {
1872     auto make_tuple = return_tuple->cast<CNodePtr>();
1873     MS_EXCEPTION_IF_NULL(make_tuple);
1874     size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
1875     // only support the last return node is a partial func now
1876     auto last_input_node = common::AnfAlgo::GetInputNode(make_tuple, tuple_input_num - 1);
1877     MS_EXCEPTION_IF_NULL(last_input_node);
1878     if (last_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(last_input_node, prim::kPrimPartial)) {
1879       size_t multi_tuple = 0;
1880       auto partial_node = last_input_node->cast<CNodePtr>();
1881       MS_EXCEPTION_IF_NULL(partial_node);
1882       size_t partial_input_num = common::AnfAlgo::GetInputTensorNum(partial_node);
1883       std::vector<AnfNodePtr> make_tuple_inputs;
1884       (void)make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1885       // skip last return node (is a partial)
1886       size_t param_begin = 0;
1887       for (size_t i = 0; i < tuple_input_num - 1; i++) {
1888         auto input = common::AnfAlgo::GetInputNode(make_tuple, i);
1889         MS_EXCEPTION_IF_NULL(input);
1890         auto node_abs = input->abstract();
1891         MS_EXCEPTION_IF_NULL(node_abs);
1892         if (node_abs->isa<abstract::AbstractSequence>()) {
1893           MS_EXCEPTION_IF_CHECK_FAIL(
1894             i == 0, "Input index: " + std::to_string(i) + " is a make tuple, input node: " + input->DebugString());
1895           MS_LOG(DEBUG) << "Flatten the make tuple, input node: " << input->DebugString()
1896                         << ", output num: " << AnfUtils::GetOutputTensorNum(input);
1897           // flatten the make tuple
1898           for (size_t j = 0; j < AnfUtils::GetOutputTensorNum(input); j++) {
1899             auto idx = NewValueNode(SizeToLong(j));
1900             MS_EXCEPTION_IF_NULL(idx);
1901             auto imm = std::make_shared<Int64Imm>(j);
1902             idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
1903             auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
1904             std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(input, j)};
1905             auto shapes = {common::AnfAlgo::GetOutputInferShape(input, j)};
1906             common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
1907             param_begin++;
1908             multi_tuple++;
1909             (void)make_tuple_inputs.emplace_back(getitem);
1910           }
1911         } else {
1912           param_begin++;
1913           (void)make_tuple_inputs.emplace_back(input);
1914         }
1915       }
1916       // skip partial graph
1917       for (size_t i = kFirstIndex; i < partial_input_num; i++) {
1918         (void)make_tuple_inputs.emplace_back(common::AnfAlgo::GetInputNode(partial_node, i));
1919       }
1920       auto g_output = graph->NewCNode(make_tuple_inputs);
1921       MS_EXCEPTION_IF_NULL(g_output);
1922       std::vector<AbstractBasePtr> abstract_list;
1923       for (size_t i = kFirstIndex; i < make_tuple_inputs.size(); ++i) {
1924         auto inputs_node = make_tuple_inputs[i];
1925         MS_EXCEPTION_IF_NULL(inputs_node);
1926         (void)abstract_list.emplace_back(inputs_node->abstract());
1927       }
1928       auto abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
1929       MS_EXCEPTION_IF_NULL(g_output);
1930       g_output->set_abstract(abstract);
1931       graph->set_output(g_output);
1932       kernel_graph_partial_map_[graph] = {abstract, common::AnfAlgo::GetInputNode(partial_node, 0), param_begin,
1933                                           common::AnfAlgo::GetInputTensorNum(g_output), multi_tuple};
1934     }
1935   }
1936 }
1937 
ConstructKernelGraph(const AnfNodePtrList & lst,const AnfNodePtrList & outputs,DeviceType device_target,bool common_opt,bool is_enable_zero_copy)1938 KernelGraphPtr KernelGraphMgr::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
1939                                                     DeviceType device_target, bool common_opt,
1940                                                     bool is_enable_zero_copy) {
1941   mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
1942   std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
1943   auto graph = NewKernelGraph();
1944   MS_EXCEPTION_IF_NULL(graph);
1945   // Set the zero copy flag in subgraph sink mode.
1946   if (is_enable_zero_copy) {
1947     MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
1948     graph->set_flag(kFlagEnableZeroCopyInGraph, true);
1949   }
1950   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
1951   graph->set_device_target(device_target);
1952   for (const auto &node : lst) {
1953     MS_EXCEPTION_IF_NULL(node);
1954     MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
1955     if (!node->isa<CNode>()) {
1956       MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
1957     }
1958     auto cnode = node->cast<CNodePtr>();
1959     MS_EXCEPTION_IF_NULL(cnode);
1960     // create a new cnode object
1961     auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
1962     MS_EXCEPTION_IF_NULL(new_cnode);
1963     if (IsOneOfPrimitiveCNode(new_cnode, {prim::kPrimCall, prim::kPrimPartial})) {
1964       auto fn = new_cnode->input(kIndexOne);
1965       MS_EXCEPTION_IF_NULL(fn);
1966       auto child_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(fn);
1967       MS_EXCEPTION_IF_NULL(child_kernel_graph);
1968       child_graph_order.push_back(std::weak_ptr<KernelGraph>(child_kernel_graph));
1969     }
1970 
1971     new_cnode->set_abstract(cnode->abstract());
1972     new_cnode->set_scope(cnode->scope());
1973     new_cnode->set_attrs(cnode->attrs());
1974 
1975     if (new_cnode->HasAttr(kAttrReplaceRealKernelInBackend)) {
1976       MS_LOG(DEBUG) << "Erase flag for node: " << new_cnode->DebugString();
1977       new_cnode->EraseAttr(kAttrReplaceRealKernelInBackend);
1978     }
1979 
1980     if (cnode->user_data<pynative::JitCallGraph>()) {
1981       new_cnode->set_user_data(cnode->user_data<pynative::JitCallGraph>());
1982     }
1983     // record map relations between anf from ME and new anf node used in backend
1984     graph->FrontBackendMapAdd(node, new_cnode);
1985     if (!graph->is_dynamic_shape() && common::AnfAlgo::IsDynamicShape(new_cnode)) {
1986       graph->SetGraphDynamicAttr(true);
1987     }
1988   }
1989   // add a make_tuple at the end of graph as output
1990   graph->set_child_graph_order(child_graph_order);
1991   graph->set_output(ConstructOutput(outputs, graph));
1992   graph->SetExecOrderByDefault();
1993 
1994 #ifndef ENABLE_SECURITY
1995   if (ExistSummaryNode(graph.get())) {
1996     graph->set_summary_node_exist(true);
1997   }
1998 #endif
1999   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
2000   if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
2001     FuncGraphManagerPtr manager = MakeManager({graph}, false);
2002     if (manager) {
2003       manager->AddFuncGraph(graph);
2004       graph->set_manager(manager);
2005     }
2006     UnifyMindIR(graph);
2007     graph->UpdateGraphAquireGilAttr();
2008     if (common_opt) {
2009       opt::BackendCommonOptimization(graph);
2010     }
2011     graph->SetInputNodes();
2012     SetInputNodeUsage(graph, manager);
2013     graph->SetOptimizerFlag();
2014   }
2015   graph->set_parameters(graph->inputs());
2016   return graph;
2017 }
2018 
ConstructKernelGraph(const FuncGraphPtr & func_graph,std::vector<KernelGraphPtr> * all_out_graph,DeviceType device_target)2019 std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGraphPtr &func_graph,
2020                                                                   std::vector<KernelGraphPtr> *all_out_graph,
2021                                                                   DeviceType device_target) {
2022   auto graph = NewKernelGraph();
2023   front_backend_graph_map_[func_graph.get()] = graph;
2024   ConstructKernelGraphInner(func_graph, all_out_graph, device_target, graph);
2025   return graph;
2026 }
2027 
ConstructPackKernelGraph(const FuncGraphPtr & func_graph,std::vector<KernelGraphPtr> * all_out_graph,DeviceType device_target)2028 std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructPackKernelGraph(const FuncGraphPtr &func_graph,
2029                                                                       std::vector<KernelGraphPtr> *all_out_graph,
2030                                                                       DeviceType device_target) {
2031   auto graph = NewPynativeKernelGraph();
2032   ConstructKernelGraphInner(func_graph, all_out_graph, device_target, graph);
2033   return graph;
2034 }
2035 
ConstructKernelGraphInner(const FuncGraphPtr & func_graph,std::vector<KernelGraphPtr> * all_out_graph,DeviceType device_target,const KernelGraphPtr & graph)2036 void KernelGraphMgr::ConstructKernelGraphInner(const FuncGraphPtr &func_graph,
2037                                                std::vector<KernelGraphPtr> *all_out_graph, DeviceType device_target,
2038                                                const KernelGraphPtr &graph) {
2039   MS_EXCEPTION_IF_NULL(func_graph);
2040   MS_EXCEPTION_IF_NULL(all_out_graph);
2041   auto node_list = TopoSort(func_graph->get_return());
2042   MS_EXCEPTION_IF_NULL(graph);
2043   auto context = MsContext::GetInstance();
2044   MS_EXCEPTION_IF_NULL(context);
2045   if (func_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline) {
2046     MS_LOG(INFO) << "Need backend inline: " << graph->graph_id();
2047     graph->set_need_inline(true);
2048   }
2049   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
2050   graph->set_device_target(device_target);
2051   // Create parameter
2052   for (const auto &node : func_graph->parameters()) {
2053     MS_EXCEPTION_IF_NULL(node);
2054     MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
2055     auto graph_inputs = graph->MutableInputs();
2056     MS_EXCEPTION_IF_NULL(graph_inputs);
2057     auto new_parameter = CreateNewParameter(node, graph.get());
2058     graph_inputs->push_back(new_parameter);
2059     graph->FrontBackendMapAdd(node, new_parameter);
2060   }
2061 
2062   std::vector<ParameterPtr> added_parameters;
2063   std::vector<std::weak_ptr<KernelGraph>> child_kernel_graphs;
2064   for (const auto &node : node_list) {
2065     MS_EXCEPTION_IF_NULL(node);
2066     if (node->isa<Parameter>()) {
2067       continue;
2068     }
2069     MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
2070     // Create value node
2071     if (node->isa<ValueNode>()) {
2072       if (NeedConvertValueNodeToParameter(node)) {
2073         ConvertValueNodeToParameter(graph, node, &added_parameters);
2074         continue;
2075       }
2076       // Create common value node
2077       if (!IsValueNode<FuncGraph>(node)) {
2078         (void)CreateNewValueNode(node, graph.get());
2079         continue;
2080       }
2081       // Create child kernel graph according ValueNode<FuncGraph>
2082       FuncGraphPtr child_graph = common::AnfAlgo::GetValueNodeFuncGraph(node);
2083       auto child_kernel_graph = front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()
2084                                   ? ConstructKernelGraph(child_graph, all_out_graph, device_target)
2085                                   : front_backend_graph_map_[child_graph.get()];
2086       (void)child_kernel_graphs.emplace_back(std::weak_ptr<KernelGraph>(child_kernel_graph));
2087       (void)CreateValueNodeKernelGraph(node, graph.get());
2088       continue;
2089     }
2090     // Create cnode
2091     if (!CreateCNodeOfKernelGraph(node, graph.get())) {
2092 #ifdef ENABLE_DUMP_IR
2093       DumpIR("construct_kernel_graph_fail.ir", func_graph);
2094 #endif
2095       MS_LOG(EXCEPTION) << "Construct func graph " << func_graph->ToString() << " failed."
2096                         << trace::DumpSourceLines(node);
2097     }
2098   }
2099 
2100   AddParameterToGraphInputs(func_graph->parameters(), graph.get());
2101   // Add ValueNode-Parameter to graph.
2102   auto graph_inputs = graph->MutableInputs();
2103   MS_EXCEPTION_IF_NULL(graph_inputs);
2104   for (auto &parameter : added_parameters) {
2105     (void)graph_inputs->emplace_back(parameter);
2106   }
2107 
2108   FuncGraphManagerPtr manager = MakeManager({graph});
2109   graph->SetInputNodes();
2110   SetInputNodeUsage(graph, manager);
2111   graph->SetExecOrderByDefault();
2112 
2113 #ifndef ENABLE_SECURITY
2114   if (ExistSummaryNode(graph.get())) {
2115     graph->set_summary_node_exist(true);
2116   }
2117 #endif
2118 
2119   all_out_graph->push_back(graph);
2120   graph->set_parameters(graph->inputs());
2121   graph->set_child_graph_order(child_kernel_graphs);
2122 }
2123 
HandleGraphInputsOutputs(const nlohmann::json & graph_json,KernelGraph * graph)2124 void HandleGraphInputsOutputs(const nlohmann::json &graph_json, KernelGraph *graph) {
2125   MS_EXCEPTION_IF_NULL(graph);
2126   auto &context = CompileCacheContext::GetInstance();
2127   if (graph_json.contains(kInputs)) {
2128     const auto &inputs_json = graph_json[kInputs];
2129     auto mutable_inputs = graph->MutableInputs();
2130     for (const auto &name : inputs_json) {
2131       AnfNodePtr node = context.FindBackNodeByBackName(name);
2132       MS_EXCEPTION_IF_NULL(node);
2133       context.InsertBackNameToBackNode(name, node);
2134       mutable_inputs->push_back(node);
2135     }
2136   }
2137   if (graph_json.contains(kParameters)) {
2138     const auto &parameters_json = graph_json[kParameters];
2139     std::vector<AnfNodePtr> parameters;
2140     for (const auto &name : parameters_json) {
2141       auto node = context.FindBackNodeByBackName(name);
2142       MS_EXCEPTION_IF_NULL(node);
2143       parameters.push_back(node);
2144     }
2145     graph->set_parameters(parameters);
2146   }
2147 
2148   if (graph_json.contains(kValidInputs)) {
2149     const auto &valid_inputs_json = graph_json[kValidInputs];
2150     auto mutable_valid_inputs = graph->MutableValidInputs();
2151     std::vector<bool> valid_inputs;
2152     (void)(std::transform(valid_inputs_json.begin(), valid_inputs_json.end(), std::back_inserter(valid_inputs),
2153                           [](const auto &val) { return val; }));
2154     *mutable_valid_inputs = valid_inputs;
2155   }
2156   const auto &front_graph_name = graph_json[kCorrespondFrontendGraph];
2157   const auto &front_graph_node = context.FindFrontNodeByFrontName(front_graph_name);
2158   FuncGraphPtr front_graph = GetValueNode<FuncGraphPtr>(front_graph_node);
2159   MS_EXCEPTION_IF_NULL(front_graph);
2160   graph->FrontBackendMapAdd(front_graph->get_return(), graph->get_return());
2161 }
2162 
HandleGraphSimpleAttr(const nlohmann::json & graph_json,KernelGraph * graph)2163 void HandleGraphSimpleAttr(const nlohmann::json &graph_json, KernelGraph *graph) {
2164   MS_EXCEPTION_IF_NULL(graph);
2165   MS_LOG(INFO) << "Handle graph " << graph->ToString() << " simple attr.";
2166   auto &context = CompileCacheContext::GetInstance();
2167   graph->set_run_mode(graph_json[kRunMode]);
2168   graph->set_is_loop_count_sink(graph_json[kIsLoopCountSink]);
2169   graph->SetGraphDynamicAttr(graph_json[kIsDynamicShape]);
2170   graph->set_device_target(graph_json[kDeviceTarget]);
2171   graph->set_root_graph_id(graph_json[kRootGraphId]);
2172   graph->set_executable(graph_json[kExecutable]);
2173   graph->set_recursive_call(graph_json[kHasRecursiveCall]);
2174   graph->set_need_inline(graph_json[kNeedInline]);
2175   graph->set_is_need_gil(graph_json[kIsNeedGil]);
2176   graph->set_is_from_single_op(graph_json[kIsFromSingleOp]);
2177   graph->set_subgraph_multi_call(graph_json[kHasSubgraphMultiCall]);
2178   graph->set_label_num(graph_json[kLabelNum]);
2179 #ifndef ENABLE_SECURITY
2180   // set summary_node of graph
2181   graph->set_summary_node_exist(graph_json[kSummaryNodeExist]);
2182 #endif
2183   if (graph_json.contains(kStartLabel)) {
2184     auto start_label = context.FindBackNodeByBackName(graph_json[kStartLabel]);
2185     if (start_label) {
2186       auto cstart_label = start_label->cast<CNodePtr>();
2187       MS_EXCEPTION_IF_NULL(cstart_label);
2188       graph->set_start_label(cstart_label);
2189     }
2190   }
2191   if (graph_json.contains(kEndGoto)) {
2192     auto end_goto = context.FindBackNodeByBackName(graph_json[kEndGoto]);
2193     if (end_goto) {
2194       auto cend_goto = end_goto->cast<CNodePtr>();
2195       MS_EXCEPTION_IF_NULL(cend_goto);
2196       graph->set_end_goto(cend_goto);
2197     }
2198   }
2199   if (graph_json.contains(kParameterUniqueNameToName)) {
2200     const auto &unique_name_to_name_json = graph_json[kParameterUniqueNameToName];
2201     for (const auto &[unique_name, name] : unique_name_to_name_json.items()) {
2202       auto node = context.FindBackNodeByBackName(unique_name);
2203       MS_EXCEPTION_IF_NULL(node);
2204       auto param = node->cast<ParameterPtr>();
2205       MS_EXCEPTION_IF_NULL(param);
2206       param->set_name(name);
2207     }
2208   }
2209   MS_LOG(INFO) << "Handle graph " << graph->ToString() << " simple attr success.";
2210 }
2211 
HandleAttrAboutOtherGraph(const mindspore::HashMap<GraphId,std::shared_ptr<KernelGraph>> & graphs,const nlohmann::json & graph_json,KernelGraph * graph)2212 void HandleAttrAboutOtherGraph(const mindspore::HashMap<GraphId, std::shared_ptr<KernelGraph>> &graphs,
2213                                const nlohmann::json &graph_json, KernelGraph *graph) {
2214   MS_EXCEPTION_IF_NULL(graph);
2215   auto &context = CompileCacheContext::GetInstance();
2216   if (graph_json.contains(kPreGraphs)) {
2217     const auto &pre_graphs_json = graph_json[kPreGraphs];
2218     for (const auto &iter : pre_graphs_json) {
2219       auto pre_graph = graphs.at(iter);
2220       MS_EXCEPTION_IF_NULL(pre_graph);
2221       graph->AddPreGraph(pre_graph);
2222     }
2223   }
2224   if (graph_json.contains(kPostGraphs)) {
2225     const auto &post_graphs_json = graph_json[kPostGraphs];
2226     for (const auto &iter : post_graphs_json) {
2227       auto post_graph = graphs.at(iter);
2228       MS_EXCEPTION_IF_NULL(post_graph);
2229       graph->AddPostGraph(post_graph);
2230     }
2231   }
2232   if (graph_json.contains(kChildGraphResult)) {
2233     const auto &child_graph_result_json = graph_json[kChildGraphResult];
2234     for (const auto &iter : child_graph_result_json) {
2235       auto node = context.FindBackNodeByBackName(iter);
2236       MS_EXCEPTION_IF_NULL(node);
2237       graph->AddChildGraphResult(node);
2238     }
2239   }
2240   if (graph_json.contains(kChildGraphOrder)) {
2241     const auto &child_graph_order_json = graph_json[kChildGraphOrder];
2242     std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
2243     for (const auto &iter : child_graph_order_json) {
2244       auto child_graph = graphs.at(iter);
2245       MS_EXCEPTION_IF_NULL(child_graph);
2246       child_graph_order.push_back(std::weak_ptr<KernelGraph>(child_graph));
2247     }
2248     graph->set_child_graph_order(child_graph_order);
2249   }
2250 }
2251 
HandleGraphComplexAttr(const mindspore::HashMap<GraphId,std::shared_ptr<KernelGraph>> & graphs,const nlohmann::json & graph_json,KernelGraph * graph)2252 void HandleGraphComplexAttr(const mindspore::HashMap<GraphId, std::shared_ptr<KernelGraph>> &graphs,
2253                             const nlohmann::json &graph_json, KernelGraph *graph) {
2254   MS_EXCEPTION_IF_NULL(graph);
2255   MS_LOG(INFO) << "Handle graph " << graph->ToString() << " complex attr.";
2256   auto &context = CompileCacheContext::GetInstance();
2257   std::vector<CNodePtr> execution_order;
2258   const auto &execution_order_json = graph_json[kExecutionOrder];
2259   for (const auto &order : execution_order_json) {
2260     auto node = context.FindBackNodeByBackName(order);
2261     MS_EXCEPTION_IF_NULL(node);
2262     execution_order.push_back(node->cast<CNodePtr>());
2263   }
2264   graph->set_execution_order(execution_order);
2265   if (graph_json.contains(kCommSubGraphIds)) {
2266     const auto &comm_sub_grpah_ids_json = graph_json[kCommSubGraphIds];
2267     for (const auto &iter : comm_sub_grpah_ids_json) {
2268       graph->RecordNewCommSubGraphId(iter);
2269     }
2270   }
2271   HandleAttrAboutOtherGraph(graphs, graph_json, graph);
2272   if (graph_json.contains(kInternalParameterToFrontNode)) {
2273     const auto &internal_parameter_to_front_node_json = graph_json[kInternalParameterToFrontNode];
2274     HashMap<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node;
2275     for (const auto &iter : internal_parameter_to_front_node_json) {
2276       const auto &back_name = iter.at(0);
2277       const auto &front_name = iter.at(kIndexOne);
2278       const auto &index = iter.at(kIndexTwo);
2279       auto back_node = context.FindBackNodeByBackName(back_name);
2280       MS_EXCEPTION_IF_NULL(back_node);
2281       auto front_node = context.FindFrontNodeByFrontName(front_name);
2282       MS_EXCEPTION_IF_NULL(front_node);
2283       internal_parameter_to_front_node[back_node] = AnfWithOutIndex(front_node, index);
2284     }
2285     graph->SetInternalParameterToFrontNodeMap(internal_parameter_to_front_node);
2286   }
2287   if (graph_json.contains(kRefInOutMap)) {
2288     const auto &ref_in_out_map_json = graph_json[kRefInOutMap];
2289     for (const auto &iter : ref_in_out_map_json) {
2290       const auto &first_name = iter.at(0);
2291       const auto &first_index = iter.at(kIndexOne);
2292       const auto &second_name = iter.at(kIndexTwo);
2293       const auto &second_index = iter.at(kIndexThree);
2294       auto first_node = context.FindBackNodeByBackName(first_name);
2295       MS_EXCEPTION_IF_NULL(first_node);
2296       auto second_node = context.FindBackNodeByBackName(second_name);
2297       MS_EXCEPTION_IF_NULL(second_node);
2298       graph->AddRefCorrespondPairs(AnfWithOutIndex(first_node, first_index),
2299                                    AnfWithOutIndex(second_node, second_index));
2300     }
2301   }
2302   if (graph_json.contains(kNodesKernelInfo)) {
2303     const auto &kernel_infos_json = graph_json[kNodesKernelInfo];
2304     LoadAnfKernelInfoFromJson(kernel_infos_json);
2305   }
2306   if (graph_json.contains(kGraphValueNodes)) {
2307     const auto &graph_value_nodes_json = graph_json[kGraphValueNodes];
2308     for (const auto &iter : graph_value_nodes_json) {
2309       auto node = context.FindBackNodeByBackName(iter);
2310       MS_EXCEPTION_IF_NULL(node);
2311       auto value_node = node->cast<ValueNodePtr>();
2312       MS_EXCEPTION_IF_NULL(value_node);
2313       graph->AddValueNodeToGraph(value_node);
2314     }
2315   }
2316 #ifndef ENABLE_SECURITY
2317   if (graph_json.contains(kSummaryNodes)) {
2318     const auto &summary_nodes_json = graph_json[kSummaryNodes];
2319     std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes;
2320     for (const auto &iter : summary_nodes_json) {
2321       const auto &first = iter.at(0);
2322       const auto &name = iter.at(kIndexOne);
2323       auto node = context.FindBackNodeByBackName(name);
2324       MS_EXCEPTION_IF_NULL(node);
2325       const auto &index = iter.at(kIndexTwo);
2326       summary_nodes[first] = std::make_pair(node, index);
2327     }
2328     graph->set_summary_nodes(summary_nodes);
2329   }
2330 #endif
2331   MS_LOG(INFO) << "Handle graph " << graph->ToString() << " complex attr success.";
2332 }
2333 
ParseKernelGraphNodesAndAttrs(const nlohmann::json & model_json)2334 bool KernelGraphMgr::ParseKernelGraphNodesAndAttrs(const nlohmann::json &model_json) {
2335   auto &context = CompileCacheContext::GetInstance();
2336   for (auto &[graph_name, graph_json] : model_json.items()) {
2337     MS_LOG(DEBUG) << "Parse graph " << graph_name << " nodes and attrs.";
2338     KernelGraphPtr graph = graphs_.at(graph_json[kGraphId]);
2339     MS_EXCEPTION_IF_NULL(graph);
2340     HandleGraphInputsOutputs(graph_json, graph.get());
2341     const auto &back_to_front = graph_json[kBackendFrontAnf];
2342     for (const auto &[back_node_name, front_node_name] : back_to_front.items()) {
2343       auto back_node = context.FindBackNodeByBackName(back_node_name);
2344       if (!back_node) {
2345         MS_LOG(EXCEPTION) << "The backend node is nullptr, its unique name is " << back_node_name;
2346       }
2347       auto front_node = context.FindFrontNodeByFrontName(front_node_name);
2348       if (!front_node) {
2349         MS_LOG(EXCEPTION) << "The frontend node is nullptr, its unique name is " << front_node_name;
2350       }
2351       if (graph->FrontendNodeExistInFrontBackendMap(front_node)) {
2352         if (graph->BackendNodeExistInFrontBackendMap(back_node)) {
2353           continue;
2354         }
2355         auto old_back_node = graph->GetBackendAnfByFrontAnf(front_node);
2356         graph->FrontBackendMapAdd(old_back_node, back_node);
2357       } else {
2358         graph->FrontBackendMapAdd(front_node, back_node);
2359       }
2360     }
2361     HandleGraphSimpleAttr(graph_json, graph.get());
2362     HandleGraphComplexAttr(graphs_, graph_json, graph.get());
2363 
2364     FuncGraphManagerPtr manager = MakeManager({graph});
2365     if (manager) {
2366       manager->AddFuncGraph(graph);
2367       graph->set_manager(manager);
2368     }
2369     graph->SetInputNodes();
2370     SetInputNodeUsage(graph, manager);
2371     graph->SetOptimizerFlag();
2372   }
2373   return true;
2374 }
2375 
ResetGetNextSharedName(const FuncGraphPtr & graph)2376 void ResetGetNextSharedName(const FuncGraphPtr &graph) {
2377   auto &config_mgr = ConfigManager::GetInstance();
2378   auto queue_name = config_mgr.QueueName();
2379   auto cnodes = graph->GetOrderedCnodes();
2380   for (const auto &cnode : cnodes) {
2381     auto prim = GetValuePtr<Primitive>(cnode->input(0));
2382     if (prim != nullptr && prim->HasAttr("shared_name")) {
2383       prim->set_attr("shared_name", MakeValue(queue_name));
2384       break;
2385     }
2386   }
2387 }
2388 
ConstructKernelGraph(std::vector<KernelGraphPtr> * all_out_graph)2389 std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(std::vector<KernelGraphPtr> *all_out_graph) {
2390   MS_LOG(WARNING) << "Use the compile cache to construct kernel graph, Be aware of correctness risks.";
2391   auto &context = CompileCacheContext::GetInstance();
2392   auto frontend_graph = context.FrontGraph();
2393   if (!frontend_graph) {
2394     MS_LOG(EXCEPTION) << "The frontend graph is null";
2395   }
2396   auto cache_path = context.GetBackendGraphCachePath(frontend_graph);
2397   std::string json_path = cache_path + kJsonSuffix;
2398   nlohmann::json model_json;
2399   auto load_json_success = LoadJson(json_path, &model_json);
2400   if (!load_json_success) {
2401     MS_LOG(EXCEPTION) << "Load json file " << json_path << " failed.";
2402   }
2403   // construct kernel graph and its params that exist correspond frontend param
2404   mindspore::HashMap<std::string, AnfNodePtr> name_to_node;
2405   (void)std::for_each(name_to_params_.begin(), name_to_params_.end(), [&name_to_node](const auto &ele) {
2406     if (!ele.second.expired()) {
2407       name_to_node[ele.first] = ele.second.lock();
2408     }
2409   });
2410   MS_LOG(DEBUG) << "Construct kernel graph and its params that exist correspond frontend param.";
2411   for (size_t i = 0; i < model_json.size(); i++) {
2412     auto kernel_graph = NewKernelGraph();
2413     all_out_graph->push_back(kernel_graph);
2414     const auto &graph_name = kernel_graph->ToString();
2415     if (!model_json.contains(graph_name)) {
2416       MS_LOG(EXCEPTION) << "Load graph " << graph_name << " from json failed.";
2417     }
2418     auto &graph_json = model_json[graph_name];
2419     if (!graph_json.contains(kCorrespondFrontendGraph)) {
2420       continue;
2421     }
2422     const auto &front_graph_name = graph_json[kCorrespondFrontendGraph];
2423     const auto &front_graph_node = context.FindFrontNodeByFrontName(front_graph_name);
2424     FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(front_graph_node);
2425     MS_EXCEPTION_IF_NULL(fg);
2426     front_backend_graph_map_[fg.get()] = kernel_graph;
2427     if (graph_json.contains(kBackendParamToFrontendParamIndex)) {
2428       const auto &backend_param_to_frontend_param_index = graph_json[kBackendParamToFrontendParamIndex];
2429       const auto &frontend_graph_params = fg->parameters();
2430       for (const auto &[param_unique_name, index] : backend_param_to_frontend_param_index.items()) {
2431         const auto front_param = frontend_graph_params.at(index);
2432         MS_EXCEPTION_IF_NULL(front_param);
2433         MS_LOG(DEBUG) << "Start create new node, old node = " << front_param->DebugString()
2434                       << ", new node = " << param_unique_name;
2435         auto new_parameter = CreateNewParameter(front_param, kernel_graph.get());
2436         kernel_graph->FrontBackendMapAdd(front_param, new_parameter);
2437         name_to_node[param_unique_name] = new_parameter;
2438       }
2439     }
2440   }
2441 
2442   std::vector<FuncGraphPtr> graphs_for_load;
2443   (void)(std::transform(all_out_graph->begin(), all_out_graph->end(), std::back_inserter(graphs_for_load),
2444                         [](const KernelGraphPtr &g) { return g; }));
2445 
2446   MindIRLoader mindir_loader;
2447   std::string mindir_path = cache_path + kMindIrSuffix;
2448   auto real_path = Common::CreatePrefixPath(mindir_path, true);
2449   if (!CheckPath(real_path)) {
2450     MS_LOG(EXCEPTION) << "The mindir path is " << mindir_path << ", and it is a invalid path!";
2451   }
2452   if (!mindir_loader.LoadMindIR(real_path.value(), graphs_for_load, &name_to_node)) {
2453     MS_LOG(EXCEPTION) << "Load mindir from " << real_path.value() << " failed.";
2454   }
2455   (void)std::for_each(name_to_node.begin(), name_to_node.end(), [](const auto &ele) {
2456     auto node = ele.second;
2457     MS_EXCEPTION_IF_NULL(node);
2458     if (node->template isa<Parameter>()) {
2459       name_to_params_[ele.first] = std::weak_ptr<AnfNode>(node);
2460     }
2461   });
2462   context.SetBackNameToBackNode(name_to_node);
2463   // the value of attr "shared_name" will changed every time, so reset GetNext shared_name
2464   ResetGetNextSharedName(all_out_graph->front());
2465   if (!ParseKernelGraphNodesAndAttrs(model_json)) {
2466     MS_LOG(EXCEPTION) << "Parse kernel graph nodes and attrs failed.";
2467   }
2468 
2469 #ifdef ENABLE_DUMP_IR
2470   auto ms_context = MsContext::GetInstance();
2471   MS_EXCEPTION_IF_NULL(ms_context);
2472   if (ms_context->CanDump(kIntroductory)) {
2473     for (const auto &iter : graphs_) {
2474       auto dump_name = std::string("loaded_") + iter.second->ToString() + ".ir";
2475       DumpIR(dump_name, iter.second);
2476     }
2477   }
2478 #endif
2479   MS_LOG(WARNING)
2480     << "Use the compile cache to construct kernel graph success, and will execute the preprocess before run directly.";
2481   return all_out_graph->front();
2482 }
2483 
SetInputNodeUsage(const KernelGraphPtr & graph,const FuncGraphManagerPtr & manager) const2484 void KernelGraphMgr::SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager) const {
2485   MS_EXCEPTION_IF_NULL(graph);
2486   MS_EXCEPTION_IF_NULL(manager);
2487   auto input_nodes = graph->input_nodes();
2488   for (auto &input_node : input_nodes) {
2489     MS_EXCEPTION_IF_NULL(input_node);
2490     if (input_node->isa<Parameter>()) {
2491       auto node_ptr = input_node->cast<ParameterPtr>();
2492       MS_EXCEPTION_IF_NULL(node_ptr);
2493       if (!IsUsedByRealKernel(manager, input_node, graph->graph_id())) {
2494         node_ptr->SetNotUsedByRealKernelInGraph(graph->graph_id());
2495       }
2496       auto shape = node_ptr->Shape();
2497       MS_EXCEPTION_IF_NULL(shape);
2498       if (shape->isa<abstract::Shape>() && shape->IsDynamic()) {
2499         node_ptr->set_has_dynamic_shape(true);
2500       }
2501       if (input_node->abstract() != nullptr && input_node->abstract()->isa<abstract::AbstractSequence>()) {
2502         // If the parameter is dynamic sequence, it is regard as dynamic shape.
2503         const auto &tuple_abs = input_node->abstract()->cast<abstract::AbstractSequencePtr>();
2504         MS_EXCEPTION_IF_NULL(tuple_abs);
2505         if (tuple_abs->dynamic_len()) {
2506           MS_LOG(INFO) << "Input node:" << input_node->DebugString() << " set dynamic flag to true";
2507           node_ptr->set_has_dynamic_shape(true);
2508         }
2509       }
2510     }
2511   }
2512 }
2513 
2514 namespace {
CNodeFirstInputIsPrimitive(const AnfNodePtr & node)2515 bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
2516   if (node == nullptr) {
2517     return false;
2518   }
2519   auto cnode = node->cast<CNodePtr>();
2520   if (cnode == nullptr) {
2521     return false;
2522   }
2523   auto prim = cnode->input(kAnfPrimitiveIndex);
2524   if (prim == nullptr || !IsValueNode<Primitive>(prim)) {
2525     return false;
2526   }
2527   return true;
2528 }
2529 
ExtendNodeUsers(const FuncGraphManagerPtr & front_func_graph_manager,const AnfNodePtr & front_node)2530 std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
2531                                         const AnfNodePtr &front_node) {
2532   MS_EXCEPTION_IF_NULL(front_func_graph_manager);
2533   auto &users = front_func_graph_manager->node_users()[front_node];
2534   std::vector<AnfNodePtr> result;
2535   for (auto &user : users) {
2536     if (common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend) ||
2537         common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimLoad)) {
2538       auto depend_cnode = user.first->cast<CNodePtr>();
2539       if (depend_cnode == nullptr) {
2540         continue;
2541       }
2542       if (front_node != depend_cnode->input(1)) {
2543         continue;
2544       }
2545       auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
2546       (void)result.insert(result.cend(), res.cbegin(), res.cend());
2547     } else if (common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
2548       auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
2549       (void)result.insert(result.cend(), res.cbegin(), res.cend());
2550     } else {
2551       (void)result.emplace_back(user.first);
2552     }
2553   }
2554   return result;
2555 }
2556 
// Resolves the front node that may be recorded as an internal output.
//   - a real kernel or a TupleGetItem is returned unchanged;
//   - a MakeTuple is resolved through its first data input (inputs[1]) when it has one;
//   - a Depend is resolved through inputs[kRealInputIndexInDepend] when it has at least kDependInputSize inputs;
//   - anything else (including non-CNodes) yields nullptr.
// Throws if `front_node` is null.
AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
  MS_EXCEPTION_IF_NULL(front_node);
  if (!front_node->isa<CNode>()) {
    return nullptr;
  }
  if (AnfUtils::IsRealKernel(front_node) ||
      common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
    return front_node;
  }

  const bool is_make_tuple = common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimMakeTuple);
  if (!is_make_tuple && !common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
    return nullptr;
  }

  const auto wrapper_cnode = front_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(wrapper_cnode);
  const auto &wrapper_inputs = wrapper_cnode->inputs();
  if (is_make_tuple) {
    if (wrapper_inputs.size() > 1) {
      return GetSupportedInternalNode(wrapper_inputs[1]);
    }
    return nullptr;
  }
  if (wrapper_inputs.size() >= kDependInputSize) {
    return GetSupportedInternalNode(wrapper_inputs[kRealInputIndexInDepend]);
  }
  return nullptr;
}
2586 
IsUnusedInternlOutput(const AnfNodePtr & user)2587 bool IsUnusedInternlOutput(const AnfNodePtr &user) {
2588   if (!CNodeFirstInputIsPrimitive(user)) {
2589     return true;
2590   }
2591   if (IsPrimitiveCNode(user, prim::kPrimSwitch) || IsPrimitiveCNode(user, prim::kPrimSwitchLayer)) {
2592     return true;
2593   }
2594   if (!AnfUtils::IsRealKernel(user)) {
2595     return true;
2596   }
2597   return false;
2598 }
2599 }  // namespace
2600 
2601 constexpr auto kMixTarget = "MixTarget";
2602 constexpr auto kNoTarget = "NoTarget";
AddPartialParametersMap(const AnfNodePtr & partial_node)2603 std::string KernelGraphMgr::AddPartialParametersMap(const AnfNodePtr &partial_node) {
2604   MS_EXCEPTION_IF_NULL(partial_node);
2605   auto iter = partial_target_map_.find(partial_node);
2606   if (iter != partial_target_map_.end()) {
2607     return iter->second;
2608   }
2609   auto partial_cnode = partial_node->cast<CNodePtr>();
2610   MS_EXCEPTION_IF_NULL(partial_cnode);
2611   auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
2612   // If graph is nullptr, it means that the funcgraph in the partial node is a deadnode, and the processing is skipped.
2613   if (partial_graph == nullptr) {
2614     return kNoTarget;
2615   }
2616   auto parameters = partial_graph->parameters();
2617   auto partial_inputs = partial_cnode->inputs();
2618   const size_t kNonParameterNum = 2;
2619   if (parameters.size() + kNonParameterNum != partial_inputs.size()) {
2620     return kMixTarget;
2621   }
2622   for (size_t i = 0; i < parameters.size(); ++i) {
2623     partial_parameters_map_[parameters[i]] = partial_inputs[kNonParameterNum + i];
2624   }
2625   auto graph_nodes = TopoSort(partial_graph->get_return());
2626   std::string graph_target = kNoTarget;
2627   for (auto &node : graph_nodes) {
2628     if (!node->isa<CNode>()) {
2629       continue;
2630     }
2631     if (!AnfUtils::IsRealKernel(node)) {
2632       continue;
2633     }
2634     std::string cur_target = GetCNodeTarget(node);
2635     if (graph_target == kNoTarget) {
2636       graph_target = cur_target;
2637     }
2638     if (graph_target != cur_target) {
2639       graph_target = kMixTarget;
2640       break;
2641     }
2642   }
2643   (void)partial_target_map_.emplace(std::pair<AnfNodePtr, std::string>(partial_node, graph_target));
2644   return graph_target;
2645 }
2646 
2647 namespace {
IsNeedAddPartialParameter(const AnfNodePtr & user,const std::string & kernel_target,const std::shared_ptr<KernelGraph> & graph)2648 bool IsNeedAddPartialParameter(const AnfNodePtr &user, const std::string &kernel_target,
2649                                const std::shared_ptr<KernelGraph> &graph) {
2650   // If the flag is enable, it means the graph would run in subgraph sink mode, the real parameter on partial
2651   // cannot share the same device address with the formal parameter.
2652   MS_EXCEPTION_IF_NULL(graph);
2653   return common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
2654          !ExistGraphCaller(user) && (!graph->has_flag(kFlagEnableZeroCopyInGraph));
2655 }
2656 }  // namespace
2657 
HandleInternalOutput(const AnfNodePtr & input_front_node,const AnfNodePtr & backend_node,const FuncGraphManagerPtr & front_func_graph_manager,const std::shared_ptr<KernelGraph> & backend_graph)2658 void KernelGraphMgr::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
2659                                           const FuncGraphManagerPtr &front_func_graph_manager,
2660                                           const std::shared_ptr<KernelGraph> &backend_graph) {
2661   MS_EXCEPTION_IF_NULL(backend_graph);
2662   auto front_node = GetSupportedInternalNode(input_front_node);
2663   if (front_node == nullptr) {
2664     return;
2665   }
2666   auto front_real_kernel_pair = common::AnfAlgo::VisitKernel(front_node, 0);
2667   auto backend_real_kernel_pair = common::AnfAlgo::VisitKernel(backend_node, 0);
2668   auto backend_real_kernel = backend_real_kernel_pair.first;
2669   if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) {
2670     return;
2671   }
2672   auto front_real_kernel = front_real_kernel_pair.first;
2673   std::string kernel_target = GetCNodeTarget(front_real_kernel);
2674   bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
2675   bool unique_target = true;
2676   if (internal_output && common::AnfAlgo::IsNopNode(front_real_kernel)) {
2677     auto pre_node_pair = common::AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
2678     auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
2679     if (pre_node_target != kernel_target) {
2680       unique_target = false;
2681     }
2682   }
2683   if (internal_output) {
2684     auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
2685     for (auto &user : users) {
2686       if (IsNeedAddPartialParameter(user, kernel_target, backend_graph)) {
2687         auto partial_target = AddPartialParametersMap(user);
2688         if (partial_target != kNoTarget && partial_target != kernel_target) {
2689           unique_target = false;
2690         }
2691         continue;
2692       }
2693       if (common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
2694         continue;
2695       }
2696       if (IsUnusedInternlOutput(user)) {
2697         internal_output = false;
2698         break;
2699       }
2700       if (kernel_target != GetCNodeTarget(user)) {
2701         unique_target = false;
2702       }
2703     }
2704   }
2705   if (internal_output) {
2706     MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString()
2707                  << ", unique_target: " << unique_target;
2708     backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
2709   }
2710 }
2711 
ConstructOutput(const AnfNodePtrList & outputs,const std::shared_ptr<KernelGraph> & graph)2712 CNodePtr KernelGraphMgr::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
2713   MS_EXCEPTION_IF_NULL(graph);
2714   std::vector<AnfNodePtr> output_args;
2715   for (const auto &output : outputs) {
2716     MS_EXCEPTION_IF_NULL(output);
2717     MS_LOG(INFO) << "Output:" << output->DebugString();
2718   }
2719   auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
2720     auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
2721     if (backend_anf != nullptr) {
2722       auto context_ptr = MsContext::GetInstance();
2723       MS_EXCEPTION_IF_NULL(context_ptr);
2724       if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
2725         return backend_anf;
2726       }
2727 
2728       MS_EXCEPTION_IF_NULL(out);
2729       auto out_func_graph = out->func_graph();
2730       MS_EXCEPTION_IF_NULL(out_func_graph);
2731       auto out_func_graph_manager = out_func_graph->manager();
2732       if (out_func_graph_manager == nullptr) {
2733         return backend_anf;
2734       }
2735       HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
2736       return backend_anf;
2737     }
2738     MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
2739   };
2740   output_args.push_back(mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())));
2741   (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
2742                        [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
2743   auto output_node = graph->NewCNode(output_args);
2744   MS_EXCEPTION_IF_NULL(output_node);
2745   // Create abstract for output maketuple node.
2746   AbstractBasePtrList output_abs_list;
2747   const auto &inputs = output_node->inputs();
2748   (void)std::transform(
2749     inputs.begin() + 1, inputs.end(), std::back_inserter(output_abs_list), [](const AnfNodePtr &input) {
2750       return input->abstract() == nullptr ? std::make_shared<abstract::AbstractNone>() : input->abstract();
2751     });
2752   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(output_abs_list);
2753   MS_EXCEPTION_IF_NULL(abstract_tuple);
2754   output_node->set_abstract(abstract_tuple);
2755   return output_node;
2756 }
2757 
NewPynativeKernelGraph()2758 KernelGraphPtr KernelGraphMgr::NewPynativeKernelGraph() {
2759   auto graph = std::make_shared<KernelGraph>();
2760   graph->set_is_from_pynative(true);
2761   MS_EXCEPTION_IF_NULL(graph);
2762   graph->set_graph_id(pynative_graph_sum_);
2763   pynative_graph_sum_++;
2764   return graph;
2765 }
2766 
NewKernelGraph()2767 KernelGraphPtr KernelGraphMgr::NewKernelGraph() {
2768   auto graph = std::make_shared<KernelGraph>();
2769   MS_EXCEPTION_IF_NULL(graph);
2770   SetKernelGraphId(graph);
2771   return graph;
2772 }
2773 
SetKernelGraphId(const KernelGraphPtr & kernel_graph)2774 void KernelGraphMgr::SetKernelGraphId(const KernelGraphPtr &kernel_graph) {
2775   MS_EXCEPTION_IF_NULL(kernel_graph);
2776   if (graph_sum_ >= kPynativeGraphIdStart) {
2777     MS_LOG(EXCEPTION) << "The graph id in graph mode must be less than " << kPynativeGraphIdStart << ", but it is "
2778                       << graph_sum_;
2779   }
2780   kernel_graph->set_graph_id(graph_sum_);
2781   graphs_[graph_sum_++] = kernel_graph;
2782 }
2783 
UnifyMindIR(const KernelGraphPtr & graph)2784 void KernelGraphMgr::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIR(graph); }
2785 
2786 namespace {
CopyCNodeInfo(const FuncGraphPtr & func_graph,const uint32_t & target_graph_id,const AnfNodePtr & ori_node,const AnfNodePtr & new_node)2787 void CopyCNodeInfo(const FuncGraphPtr &func_graph, const uint32_t &target_graph_id, const AnfNodePtr &ori_node,
2788                    const AnfNodePtr &new_node) {
2789   MS_EXCEPTION_IF_NULL(new_node);
2790   MS_EXCEPTION_IF_NULL(ori_node);
2791   auto kernel_info = dynamic_cast<device::KernelInfo *>(new_node->kernel_info());
2792   // deep copy kernel info
2793   if (kernel_info != nullptr && kernel_info->has_build_info()) {
2794     // some check
2795     MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->MutableKernelMod() == nullptr,
2796                                "Inline ERROR: " + ori_node->DebugString() + ", kernel mod is not nullptr");
2797     MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->output_address_list().empty(),
2798                                "Inline ERROR: " + ori_node->DebugString() + ", output_address_list is not empty");
2799     MS_EXCEPTION_IF_CHECK_FAIL(kernel_info->workspace_address_list().empty(),
2800                                "Inline ERROR: " + ori_node->DebugString() + ", workspace_address_list is not empty");
2801 
2802     auto new_kernel_info = std::make_shared<device::KernelInfo>();
2803     auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(
2804       AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(new_node));
2805     MS_EXCEPTION_IF_NULL(builder);
2806     MS_EXCEPTION_IF_NULL(new_kernel_info);
2807     new_kernel_info->set_select_kernel_build_info(builder->Build());
2808     new_kernel_info->set_graph_id(target_graph_id);
2809     new_kernel_info->set_feature_map_flag(kernel_info->is_feature_map());
2810     new_kernel_info->set_ref_map(false, kernel_info->out_in_ref_map());
2811     new_node->set_kernel_info(new_kernel_info);
2812   } else {
2813     auto new_kernel_info = std::make_shared<device::KernelInfo>();
2814     new_node->set_kernel_info(new_kernel_info);
2815   }
2816   if (ori_node->isa<CNode>()) {
2817     auto ori_cnode = ori_node->cast<CNodePtr>();
2818     if (common::AnfAlgo::HasNodeAttr(kAttrIsUBFusionOp, ori_cnode) &&
2819         common::AnfAlgo::GetNodeAttr<bool>(ori_node, kAttrIsUBFusionOp)) {
2820       // already done fusion compile
2821       auto ori_full_name = ori_cnode->fullname_with_scope();
2822       common::AnfAlgo::SetNodeAttr(kAttrOriFusionName, MakeValue(ori_full_name), new_node);
2823     }
2824     common::AnfAlgo::SetNodeAttr(kAttrNeedInline, MakeValue(ori_node->fullname_with_scope()), new_node);
2825     common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node);
2826   }
2827 }
2828 
UpdateConditionNodePair(const KernelGraphPtr & kernel_graph,const KernelGraphPtr & target_kernel_graph,const mindspore::HashMap<AnfNodePtr,AnfNodePtr> & condition_node_map)2829 void UpdateConditionNodePair(const KernelGraphPtr &kernel_graph, const KernelGraphPtr &target_kernel_graph,
2830                              const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &condition_node_map) {
2831   MS_EXCEPTION_IF_NULL(kernel_graph);
2832   const auto &gather_to_switch = kernel_graph->condition_gather_to_switch();
2833   for (const auto &pair : gather_to_switch) {
2834     MS_EXCEPTION_IF_NULL(pair.first);
2835     MS_EXCEPTION_IF_NULL(pair.second);
2836     const auto &gather_iter = condition_node_map.find(pair.first);
2837     const auto &switch_iter = condition_node_map.find(pair.second);
2838     if (gather_iter == condition_node_map.end() || switch_iter == condition_node_map.end()) {
2839       MS_LOG(EXCEPTION) << "Failed to get new gather node:" << pair.first->fullname_with_scope()
2840                         << " or switch node:" << pair.second->fullname_with_scope()
2841                         << " in graph:" << kernel_graph->ToString();
2842     }
2843     MS_EXCEPTION_IF_NULL(gather_iter->second);
2844     MS_EXCEPTION_IF_NULL(switch_iter->second);
2845     if (target_kernel_graph != nullptr) {
2846       target_kernel_graph->AddConditionGatherSwitchPair(gather_iter->second, switch_iter->second);
2847       MS_LOG(INFO) << "Add condition node pair:" << gather_iter->second->fullname_with_scope()
2848                    << " and:" << switch_iter->second->fullname_with_scope()
2849                    << " for graph:" << target_kernel_graph->ToString();
2850       const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(pair.second);
2851       if (front_node == nullptr) {
2852         MS_LOG(WARNING) << "Failed to get front node by backend node:" << pair.second->DebugString()
2853                         << " in graph:" << kernel_graph->ToString();
2854         continue;
2855       }
2856       target_kernel_graph->FrontBackendMapAdd(front_node, switch_iter->second);
2857     }
2858   }
2859 }
2860 }  // namespace
2861 
DoInline(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & func_graph_args,const ScopePtr & scope,const uint32_t & target_graph_id,const std::map<session::AnfWithOutIndex,session::AnfWithOutIndex> & ref_map,const KernelGraphPtr & graph,bool is_switch_inline)2862 AnfNodePtr KernelGraphMgr::DoInline(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
2863                                     const AnfNodePtrList &func_graph_args, const ScopePtr &scope,
2864                                     const uint32_t &target_graph_id,
2865                                     const std::map<session::AnfWithOutIndex, session::AnfWithOutIndex> &ref_map,
2866                                     const KernelGraphPtr &graph, bool is_switch_inline) {
2867   MS_EXCEPTION_IF_NULL(func_graph);
2868   MS_EXCEPTION_IF_NULL(graph);
2869   MS_EXCEPTION_IF_NULL(target_func_graph);
2870   KernelGraphPtr target_kernel_graph = nullptr;
2871   if (target_func_graph->isa<KernelGraph>()) {
2872     target_kernel_graph = target_func_graph->cast<KernelGraphPtr>();
2873   }
2874   Cloner cloner({}, false);
2875   if (scope != nullptr) {
2876     cloner.set_scope(scope);
2877   }
2878   cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
2879   auto node_list = TopoSort(func_graph->output());
2880   mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_node_map;
2881   for (auto &ori_node : node_list) {
2882     MS_EXCEPTION_IF_NULL(ori_node);
2883     if (ori_node->isa<Parameter>()) {
2884       continue;
2885     }
2886     auto new_node = cloner[ori_node];
2887     MS_EXCEPTION_IF_NULL(new_node);
2888     if (new_node->isa<ValueNode>()) {
2889       auto value_node = new_node->cast<ValueNodePtr>();
2890       MS_EXCEPTION_IF_NULL(value_node);
2891       graph->AddValueNodeToGraph(value_node);
2892     }
2893     // Add sub graph kernel for switch inline kernel graph.
2894     if (new_node->isa<CNode>() && target_kernel_graph != nullptr && is_switch_inline) {
2895       MS_LOG(DEBUG) << "Add inline sub graph for kernel:" << new_node->fullname_with_scope()
2896                     << " graph:" << func_graph->ToString();
2897       std::string sub_graph_name = func_graph->ToString();
2898       if (func_graph->isa<KernelGraph>()) {
2899         const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
2900         MS_EXCEPTION_IF_NULL(kernel_graph);
2901         const auto &sub_graph_iter = kernel_graph->inline_sub_graph_kernels().find(ori_node);
2902         if (sub_graph_iter != kernel_graph->inline_sub_graph_kernels().end()) {
2903           sub_graph_name = sub_graph_iter->second;
2904         }
2905       }
2906       target_kernel_graph->AddInlineSubgraphKernel(new_node, sub_graph_name);
2907       if (common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionGather) ||
2908           common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionSwitch)) {
2909         condition_node_map[ori_node] = new_node;
2910       }
2911     }
2912     CopyCNodeInfo(func_graph, target_graph_id, ori_node, new_node);
2913   }
2914   // Collect condition gather node and condition switch node.
2915   if (func_graph->isa<KernelGraph>() && is_switch_inline) {
2916     const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
2917     MS_EXCEPTION_IF_NULL(kernel_graph);
2918     UpdateConditionNodePair(kernel_graph, target_kernel_graph, condition_node_map);
2919   }
2920 
2921   for (const auto &kv : ref_map) {
2922     auto final_pair = kv.first;
2923     auto origin_pair = kv.second;
2924     final_pair.first = cloner[final_pair.first];
2925     origin_pair.first = cloner[origin_pair.first];
2926     auto new_origin_pair = common::AnfAlgo::VisitKernel(origin_pair.first, origin_pair.second);
2927     graph->AddRefCorrespondPairs(final_pair, new_origin_pair);
2928   }
2929   return cloner[func_graph->output()];
2930 }
2931 }  // namespace session
2932 }  // namespace mindspore
2933