/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifdef ENABLE_AKG
#include "backend/common/graph_kernel/graph_kernel_build.h"

#include <fstream>
#include <utility>
#include <string>
#include <map>
#include <unordered_set>
#include <algorithm>
#include "mindspore/core/ops/framework_ops.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "kernel/graph_kernel/graph_kernel_json_generator.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "kernel/graph_kernel/graph_kernel_builder_manager.h"
#include "backend/common/graph_kernel/symbol_engine/multi_symbol_engine.h"

namespace mindspore::graphkernel {
namespace {
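// Collects all CNodes of the graph that are real kernels, in topological order.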
void GetTopoValidNodes(const FuncGraphPtr &func_graph, CNodePtrList *topo_valid_nodes) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(topo_valid_nodes);
  auto nodes = TopoSort(func_graph->get_return());
  for (auto &node : nodes) {
    if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    topo_valid_nodes->push_back(cnode);
  }
}

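// Returns true if the node is one of the AKG-specific ops (UnPadAkg, PadAkg, ElemAny).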
bool IsAkgOp(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  static std::unordered_set<std::string> ops{"UnPadAkg", "PadAkg", "ElemAny"};
  auto name = AnfUtils::GetCNodeName(node);
  return ops.find(name) != ops.end();
}
}  // namespace

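// Runs the safe split and reports whether a non-empty split plan was produced.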
bool SafeSplitSchemer::Split(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Run(func_graph);
  return !split_plan_.empty();
}

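// Builds the split plan. The plan is discarded when it is inconsistent with the inline flags, empty, or consists of
// a single group that stays as one sub graph (i.e. no real split happens).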
void SafeSplitSchemer::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  SplitNodes(func_graph);
  if (split_plan_.size() != need_inline_.size() || split_plan_.empty() || (split_plan_.size() == 1 && !NeedInline(0))) {
    split_plan_.clear();
    need_inline_.clear();
    return;
  }
  GroupReturnNode(func_graph);
}

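// Assigns every real-kernel node to its own group. A group is kept as a sub graph (need_inline = 0) only when it
// contains more than one node or a single AKG-specific op; all other groups are inlined back into the main graph.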
void SafeSplitSchemer::SplitNodes(const FuncGraphPtr &func_graph) {
  CNodePtrList topo_valid_nodes;
  GetTopoValidNodes(func_graph, &topo_valid_nodes);
  for (size_t i = 0; i < topo_valid_nodes.size(); ++i) {
    const auto &node = topo_valid_nodes[i];
    node_group_[node] = i;
  }

  std::map<size_t, AnfNodePtrList> group_nodes;
  // Nodes with the same group id stay in the same group.
  for (const auto &node : topo_valid_nodes) {
    auto group_id = node_group_[node];
    group_nodes[group_id].push_back(node);
  }

  node_group_.clear();
  for (const auto &it : group_nodes) {
    for (const auto &node : it.second) {
      node_group_[node] = split_plan_.size();
    }
    split_plan_.push_back(it.second);
    // If a group has two or more nodes, or contains an AKG-specific node, it stays in a sub graph (need_inline = 0).
    if (it.second.size() > 1 || (it.second.size() == 1 && IsAkgOp(it.second.back()))) {
      need_inline_.push_back(0);
    } else {
      need_inline_.push_back(1);
    }
  }
}

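// Prepares the compile cache (KernelMeta) and selects the kernel builder that matches the target device and the
// kernel_generator flag.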
void GraphKernelBuild::Init() {
  // Init KernelMeta.
  if (bin_map_ == nullptr) {
    bin_map_ = kernel::KernelMeta::GetInstance();
    if (!bin_map_->initialized()) {
      bin_map_->Initialize();
    }
  }

  // Init AkgKernelBuilder.
  auto device_type = Callback::Instance()->GetTargetFromContext();
  bool is_akg_v2 = (GraphKernelFlags::GetInstance().kernel_generator == "AKG_V2");
  kernel_builder_ = kernel::GraphKernelBuildManager::Instance().GetGraphKernelBuilder(device_type, is_akg_v2);
  if (kernel_builder_ == nullptr) {
    MS_EXCEPTION(UnknownError) << "Can't find a corresponding kernel builder for device " << device_type
                               << " with kernel_generator flag "
                               << GraphKernelFlags::GetInstance().kernel_generator << ".";
  }
}

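// Performs one compilation round: collect graph kernel nodes, reuse cached kernels, compile the remaining nodes in
// parallel, and split the nodes that failed to compile. Returns true if the graph was changed.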
bool GraphKernelBuild::Process(const FuncGraphPtr &func_graph, int iter) {
  bool changed = false;
  std::vector<kernel::JsonNodePair> nodes;
  CollectNodes(func_graph, &nodes);
  // No nodes need to be compiled.
  if (nodes.empty()) {
    MS_LOG(DEBUG) << "There are no Akg kernels to be compiled.";
    return changed;
  }
  // Update the cache before compiling. Some nodes may already have a compile cache (e.g. compiled during a previous
  // network run); these nodes do not need to be compiled again.
  auto need_compile_nodes = CollectNotCachedNodes(nodes);
  MS_LOG(INFO) << "Iter " << iter << ": Total Akg kernel number is " << nodes.size() << ", "
               << need_compile_nodes.size() << " of them need to be compiled, and "
               << (nodes.size() - need_compile_nodes.size()) << " of them use the compilation cache.";
  // Parallel compile.
  ParallelBuild(need_compile_nodes);
  // Update the cache after compiling. Nodes that still have no compile cache failed to compile.
  changed = SplitNodesByKernelCompiler(nodes);
  auto remaining_nodes = CollectNotCachedNodes(need_compile_nodes);
  // Split the nodes that failed to compile.
  changed = changed || SplitNodes(remaining_nodes);
  return changed;
}

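// Generates the fused-kernel json for a single graph kernel node, attaching a symbol engine when the node has
// dynamic shape.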
kernel::JsonNodePair GraphKernelBuild::CollectNode(const AnfNodePtr &node) const {
  FuncGraphPtr sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(sub_func_graph);
  auto mng = sub_func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(sub_func_graph, true);
    sub_func_graph->set_manager(mng);
  }
  AnfNodePtrList node_list;
  AnfNodePtrList input_list;
  AnfNodePtrList output_list;
  kernel::GetValidKernelNodes(sub_func_graph, &node_list, &input_list, &output_list);
  DumpOption option;
  option.get_target_info = true;
  option.save_ptr_address = true;
  GraphKernelJsonGenerator graph_kernel_json_generator(option);
  if (sub_func_graph->symbol_engine() != nullptr) {
    graph_kernel_json_generator.set_symbol_engine(sub_func_graph->symbol_engine());
  } else if (common::AnfAlgo::IsDynamicShape(node)) {
    symshape::MultiSymbolEngine::BuildSubEngine(node);
    graph_kernel_json_generator.set_symbol_engine(sub_func_graph->symbol_engine());
  }
  if (!graph_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
    MS_EXCEPTION(UnknownError) << "Collect op info file failed. op[" << node->fullname_with_scope() << "].";
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  sub_func_graph->set_attr("info_name", MakeValue(graph_kernel_json_generator.kernel_name()));
  return std::make_pair(graph_kernel_json_generator, node);
}

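// Collects, in reversed topological order, all graph kernel nodes that do not have a kernel mod yet.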
void GraphKernelBuild::CollectNodes(const FuncGraphPtr &func_graph, std::vector<kernel::JsonNodePair> *nodes) const {
  if (func_graph == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(nodes);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto todo = TopoSort(func_graph->get_return());
  for (auto iter = todo.crbegin(); iter != todo.crend(); ++iter) {
    auto node = *iter;
    // Only process graph kernel nodes that do not have a kernel mod yet.
    if (node == nullptr || !common::AnfAlgo::IsGraphKernel(node) || AnfAlgo::GetKernelMod(node) != nullptr) {
      continue;
    }
    auto json_node = CollectNode(node);
    nodes->push_back(json_node);
  }
}

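// Returns the kernel name stored on the node's sub graph by the kernel compiler splitter, or an empty string.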
std::string GetGraphKernelNodeName(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto func_graph = GetCNodeFuncGraph(cnode);
  MS_EXCEPTION_IF_NULL(func_graph);
  if (func_graph->has_attr(kAttrNodeName)) {
    return GetValue<std::string>(func_graph->get_attr(kAttrNodeName));
  }
  return std::string();
}

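// Returns the nodes that have no compile cache yet. For nodes whose kernel is cached (in memory, or as a json file
// in the kernel meta directory), the kernel mod is set directly instead.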
std::vector<kernel::JsonNodePair> GraphKernelBuild::CollectNotCachedNodes(
  const std::vector<kernel::JsonNodePair> &nodes) {
  MS_EXCEPTION_IF_NULL(bin_map_);
  MS_EXCEPTION_IF_NULL(kernel_builder_);
  std::vector<kernel::JsonNodePair> res;
  for (const auto &[json_generator, node] : nodes) {
    if (node == nullptr) {
      continue;
    }
    // Skip nodes whose kernel mod is already set (created from the compile cache).
    if (AnfAlgo::GetKernelMod(node) != nullptr) {
      MS_LOG(DEBUG) << "Skip node that already has a kernel mod: " << json_generator.kernel_name();
      continue;
    }
    auto kernel_name = json_generator.kernel_name();
    // Skip nodes that already have a cache.
    if (kernel_pack_.find(kernel_name) != kernel_pack_.end()) {
      kernel_builder_->SetKernelMod(kernel_pack_[kernel_name], json_generator, node);
      MS_LOG(DEBUG) << "Set cached kernel for node [" << node->fullname_with_scope() << "] with kernel name ["
                    << kernel_name << "]";
      continue;
    }

    std::string split_kernel_name = GetGraphKernelNodeName(node);
    // Check whether the node is a split node that already has a cache.
    if (kernel_pack_.find(split_kernel_name) != kernel_pack_.end()) {
      kernel_builder_->SetKernelMod(kernel_pack_[split_kernel_name], json_generator, node);
      MS_LOG(DEBUG) << "Set cached kernel for node [" << node->fullname_with_scope() << "] with kernel node name ["
                    << split_kernel_name << "]";
      continue;
    }

    std::string split_result_path = bin_map_->kernel_meta_path() + kernel_name + "_split" + kernel::kJsonSuffix;
    std::ifstream split_result_json(split_result_path);
    // The split json file exists, which means the node was split by the kernel compiler.
    if (split_result_json.is_open()) {
      // check split result
      MS_LOG(DEBUG) << "The node is split by the kernel compiler: " << kernel_name;
      split_result_json.close();
      continue;
    }

    std::string json_path = bin_map_->kernel_meta_path() + kernel_name + kernel::kJsonSuffix;
    std::ifstream kernel_json(json_path);
    // The json file does not exist, which means the node has no cache.
    if (!kernel_json.is_open()) {
      std::string split_json_path = bin_map_->kernel_meta_path() + split_kernel_name + kernel::kJsonSuffix;
      std::ifstream split_kernel_json(split_json_path);
      if (!split_kernel_json.is_open()) {
        (void)res.emplace_back(json_generator, node);
        MS_LOG(DEBUG) << "The node does not have a cache, as the json [" << node->fullname_with_scope()
                      << "] with kernel name [" << kernel_name << "] is not found.";
        continue;
      } else {
        MS_LOG(DEBUG) << "The node has a cache with a split kernel, as the json [" << node->fullname_with_scope()
                      << "] with kernel name [" << split_kernel_name << "] is found.";
        kernel_name = split_kernel_name;
        json_path = split_json_path;
        split_kernel_json.close();
      }
    } else {
      kernel_json.close();
    }

    // For GPU and CPU, the json path must be inserted into bin_map_ (KernelMeta) first, otherwise SearchKernelCache
    // will fail.
    (void)bin_map_->Insert(kernel_name, json_path);
    auto cached_kernel_pack = kernel_builder_->SearchKernelCache(kernel_name);
    // Node cache found.
    if (cached_kernel_pack != nullptr) {
      kernel_pack_[kernel_name] = cached_kernel_pack;
      kernel_builder_->SetKernelMod(cached_kernel_pack, json_generator, node);
      MS_LOG(DEBUG) << "Set cached kernel for node [" << node->fullname_with_scope() << "] with kernel name ["
                    << kernel_name << "]";
      continue;
    }
    // Node cache not found.
    (void)res.emplace_back(json_generator, node);
  }
  return res;
}

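// Compiles the given nodes in parallel, first removing nodes that duplicate an already-seen kernel name.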
void GraphKernelBuild::ParallelBuild(const std::vector<kernel::JsonNodePair> &nodes) {
  std::vector<kernel::JsonNodePair> uniq_nodes;
  std::unordered_set<std::string> kernel_names;
  // The kernel builder's ParallelBuild can not process duplicate nodes, so these nodes need to be filtered out
  // first.
  for (const auto &[json_generator, node] : nodes) {
    const auto &kernel_name = json_generator.kernel_name();
    if (kernel_names.find(kernel_name) == kernel_names.end()) {
      (void)kernel_names.insert(kernel_name);
      (void)uniq_nodes.emplace_back(json_generator, node);
    }
  }
  if (!uniq_nodes.empty()) {
    MS_EXCEPTION_IF_NULL(kernel_builder_);
    (void)kernel_builder_->ParallelBuild(uniq_nodes);
  }
}

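// Splits the nodes that failed to compile into smaller kernels; raises an exception when a node can neither be
// compiled nor split.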
bool GraphKernelBuild::SplitNodes(const std::vector<kernel::JsonNodePair> &nodes) {
  bool result = false;
  std::unordered_set<std::string> kernel_names;
  for (const auto &[json_generator, node] : nodes) {
    const auto &kernel_name = json_generator.kernel_name();
    // Print the kernel name of nodes that failed to compile.
    if (kernel_names.find(kernel_name) == kernel_names.end()) {
      (void)kernel_names.insert(kernel_name);
      MS_LOG(WARNING) << "Nodes with kernel name [" << kernel_name
                      << "] have no compile cache after compiling and will be split.";
    }
    MS_EXCEPTION_IF_NULL(node);
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (!splitter_.TrySplit(cnode)) {
      // The node failed to compile and can not be split either.
      MS_LOG(EXCEPTION) << "Node [" << node->fullname_with_scope() << "] with kernel name [" << kernel_name
                        << "] failed to compile and can not be split.";
    }
    result = true;
  }
  return result;
}

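// Splits nodes according to the split result json ("<kernel_name>_split.json") emitted by the kernel compiler into
// the kernel meta directory.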
bool GraphKernelBuild::SplitNodesByKernelCompiler(const std::vector<kernel::JsonNodePair> &nodes) {
  MS_EXCEPTION_IF_NULL(bin_map_);
  MS_EXCEPTION_IF_NULL(kernel_builder_);
  bool result = false;
  KernelCompilerGraphKernelSplitter compiler_splitter;
  for (const auto &[json_generator, node] : nodes) {
    if (node == nullptr) {
      continue;
    }
    const auto &kernel_name = json_generator.kernel_name();

    std::string split_json_path = bin_map_->kernel_meta_path() + kernel_name + "_split" + kernel::kJsonSuffix;
    std::ifstream kernel_split_json(split_json_path);
    // The json file does not exist, which means the node was not split by the kernel compiler.
    if (!kernel_split_json.is_open()) {
      continue;
    }
    nlohmann::json js;
    kernel_split_json >> js;
    kernel_split_json.close();

    std::map<std::string, AnfNodePtr> address_node_map = json_generator.address_node_map();
    compiler_splitter.SetAddressNodeMap(address_node_map);
    compiler_splitter.SetJson(js.dump());
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto ori_sub_func_graph = GetCNodeFuncGraph(cnode);
    ori_sub_func_graph->set_attr(kAttrNodeName, MakeValue(kernel_name));
    if (!compiler_splitter.TrySplit(cnode)) {
      // The node failed to compile and can not be split either.
      MS_LOG(EXCEPTION) << "Node [" << node->fullname_with_scope() << "] with kernel name [" << kernel_name
                        << "] failed to compile and can not be split.";
    }
    result = true;
  }
  return result;
}

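// Pass entry point: repeats Process until no node is compiled or split anymore, keeping the manager roots clean
// between iterations.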
bool GraphKernelBuild::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }

  Init();

  bool changed = false;
  bool need_traverse = true;
  int iter = 1;
  while (need_traverse) {
    need_traverse = Process(func_graph, iter);
    iter++;
    changed = need_traverse || changed;
    if (need_traverse) {
      mng->RemoveRoots();
      mng->KeepRoots({func_graph});
    }
  }

  return changed;
}
}  // namespace mindspore::graphkernel
#endif