1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifdef ENABLE_AKG
17 #include "backend/common/graph_kernel/graph_kernel_build.h"
18
19 #include <fstream>
20 #include <utility>
21 #include <string>
22 #include <map>
23 #include <unordered_set>
24 #include <algorithm>
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 #include "kernel/graph_kernel/graph_kernel_json_generator.h"
29 #include "backend/common/graph_kernel/graph_kernel_helper.h"
30 #include "backend/common/graph_kernel/graph_kernel_flags.h"
31 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
32 #include "kernel/graph_kernel/graph_kernel_builder_manager.h"
33 #include "backend/common/graph_kernel/symbol_engine/multi_symbol_engine.h"
34
35 namespace mindspore::graphkernel {
36 namespace {
GetTopoValidNodes(const FuncGraphPtr & func_graph,CNodePtrList * topo_valid_nodes)37 void GetTopoValidNodes(const FuncGraphPtr &func_graph, CNodePtrList *topo_valid_nodes) {
38 MS_EXCEPTION_IF_NULL(func_graph);
39 MS_EXCEPTION_IF_NULL(topo_valid_nodes);
40 auto nodes = TopoSort(func_graph->get_return());
41 for (auto &node : nodes) {
42 if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
43 continue;
44 }
45 auto cnode = node->cast<CNodePtr>();
46 MS_EXCEPTION_IF_NULL(cnode);
47 topo_valid_nodes->push_back(cnode);
48 }
49 }
50
IsAkgOp(const AnfNodePtr & node)51 bool IsAkgOp(const AnfNodePtr &node) {
52 if (node == nullptr || !node->isa<CNode>()) {
53 return false;
54 }
55 static std::unordered_set<std::string> ops{"UnPadAkg", "PadAkg", "ElemAny"};
56 auto name = AnfUtils::GetCNodeName(node);
57 return ops.find(name) != ops.end();
58 }
59 } // namespace
60
// Entry point of the split schemer: computes a split plan for `func_graph`.
// Returns true when a non-empty split plan was produced (see Run for the validity rules).
bool SafeSplitSchemer::Split(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Run(func_graph);
  return !split_plan_.empty();
}
66
Run(const FuncGraphPtr & func_graph)67 void SafeSplitSchemer::Run(const FuncGraphPtr &func_graph) {
68 auto mng = func_graph->manager();
69 if (mng == nullptr) {
70 mng = Manage(func_graph, true);
71 func_graph->set_manager(mng);
72 }
73 SplitNodes(func_graph);
74 if (split_plan_.size() != need_inline_.size() || split_plan_.empty() || (split_plan_.size() == 1 && !NeedInline(0))) {
75 split_plan_.clear();
76 need_inline_.clear();
77 return;
78 }
79 GroupReturnNode(func_graph);
80 }
81
SplitNodes(const FuncGraphPtr & func_graph)82 void SafeSplitSchemer::SplitNodes(const FuncGraphPtr &func_graph) {
83 CNodePtrList topo_valid_nodes;
84 GetTopoValidNodes(func_graph, &topo_valid_nodes);
85 for (size_t i = 0; i < topo_valid_nodes.size(); ++i) {
86 const auto &node = topo_valid_nodes[i];
87 node_group_[node] = i;
88 }
89
90 std::map<size_t, AnfNodePtrList> group_nodes;
91 // Nodes with same group id will stay in the same group.
92 for (const auto &node : topo_valid_nodes) {
93 auto group_id = node_group_[node];
94 group_nodes[group_id].push_back(node);
95 }
96
97 node_group_.clear();
98 for (const auto &it : group_nodes) {
99 for (const auto &node : it.second) {
100 node_group_[node] = split_plan_.size();
101 }
102 split_plan_.push_back(it.second);
103 // If a group has >= 2 nodes or AKG specific node, then this group will stay in a sub graph(need_inline = 0).
104 if (it.second.size() > 1 || (it.second.size() == 1 && IsAkgOp(it.second.back()))) {
105 need_inline_.push_back(0);
106 } else {
107 need_inline_.push_back(1);
108 }
109 }
110 }
111
Init()112 void GraphKernelBuild::Init() {
113 // Init KernelMeta.
114 if (bin_map_ == nullptr) {
115 bin_map_ = kernel::KernelMeta::GetInstance();
116 if (!bin_map_->initialized()) {
117 bin_map_->Initialize();
118 }
119 }
120
121 // Init AkgKernelBuilder.
122 auto device_type = Callback::Instance()->GetTargetFromContext();
123 bool is_akg_v2 = (GraphKernelFlags::GetInstance().kernel_generator == "AKG_V2");
124 kernel_builder_ = kernel::GraphKernelBuildManager::Instance().GetGraphKernelBuilder(device_type, is_akg_v2);
125 if (kernel_builder_ == nullptr) {
126 MS_EXCEPTION(UnknownError) << "Can't find corresponding kernel builder for device: " << device_type
127 << ", and kernel_generator flag to be: "
128 << GraphKernelFlags::GetInstance().kernel_generator << " .";
129 }
130 }
131
Process(const FuncGraphPtr & func_graph,int iter)132 bool GraphKernelBuild::Process(const FuncGraphPtr &func_graph, int iter) {
133 bool changed = false;
134 std::vector<kernel::JsonNodePair> nodes;
135 CollectNodes(func_graph, &nodes);
136 // No nodes need to be compiled.
137 if (nodes.empty()) {
138 MS_LOG(DEBUG) << "There are no Akg kernel to be compiled.";
139 return changed;
140 }
141 // Update cache before compiling. Some nodes may already have compiled cache(e.g. compiled from previous network
142 // running), these nodes do not need to be compiled again.
143 auto need_compile_nodes = CollectNotCachedNodes(nodes);
144 MS_LOG(INFO) << "Iter " << iter << ": Total Akg kernel number is " << nodes.size() << ", "
145 << need_compile_nodes.size() << " of them need to be compiled, and "
146 << (nodes.size() - need_compile_nodes.size()) << " of them use the compilation cache.";
147 // Parallel compile.
148 ParallelBuild(need_compile_nodes);
149 // Update cache after compiling. Nodes that still not have compile cache means they compiled failed.
150 changed = SplitNodesByKernelCompiler(nodes);
151 auto remaining_nodes = CollectNotCachedNodes(need_compile_nodes);
152 // Split nodes that compile failed.
153 changed = changed || SplitNodes(remaining_nodes);
154 return changed;
155 }
156
// Generates the kernel json descriptor for a single graph kernel node and
// returns it paired with the node. Throws if the json cannot be collected.
kernel::JsonNodePair GraphKernelBuild::CollectNode(const AnfNodePtr &node) const {
  FuncGraphPtr sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(sub_func_graph);
  auto mng = sub_func_graph->manager();
  if (mng == nullptr) {
    // The sub graph may not be managed yet; attach a manager so it can be traversed.
    mng = Manage(sub_func_graph, true);
    sub_func_graph->set_manager(mng);
  }
  AnfNodePtrList node_list;
  AnfNodePtrList input_list;
  AnfNodePtrList output_list;
  kernel::GetValidKernelNodes(sub_func_graph, &node_list, &input_list, &output_list);
  DumpOption option;
  option.get_target_info = true;
  // Keep node pointer addresses in the json so split results can be mapped back
  // to graph nodes later (see address_node_map usage in SplitNodesByKernelCompiler).
  option.save_ptr_address = true;
  GraphKernelJsonGenerator graph_kernel_json_generator(option);
  // For dynamic-shape kernels, a symbol engine is required: reuse the sub graph's
  // engine when present, otherwise build a dedicated sub engine for this node.
  if (sub_func_graph->symbol_engine() != nullptr) {
    graph_kernel_json_generator.set_symbol_engine(sub_func_graph->symbol_engine());
  } else if (common::AnfAlgo::IsDynamicShape(node)) {
    symshape::MultiSymbolEngine::BuildSubEngine(node);
    graph_kernel_json_generator.set_symbol_engine(sub_func_graph->symbol_engine());
  }
  if (!graph_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
    MS_EXCEPTION(UnknownError) << "Collect op info file failed. op[" << node->fullname_with_scope() << "].";
  }
  // Validity check: the node must be a CNode (the cast result itself is unused).
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Record the generated kernel name on the sub graph for later lookup.
  sub_func_graph->set_attr("info_name", MakeValue(graph_kernel_json_generator.kernel_name()));
  return std::make_pair(graph_kernel_json_generator, node);
}
187
CollectNodes(const FuncGraphPtr & func_graph,std::vector<kernel::JsonNodePair> * nodes) const188 void GraphKernelBuild::CollectNodes(const FuncGraphPtr &func_graph, std::vector<kernel::JsonNodePair> *nodes) const {
189 if (func_graph == nullptr) {
190 return;
191 }
192 MS_EXCEPTION_IF_NULL(nodes);
193 auto manager = func_graph->manager();
194 MS_EXCEPTION_IF_NULL(manager);
195 auto todo = TopoSort(func_graph->get_return());
196 for (auto iter = todo.crbegin(); iter != todo.crend(); ++iter) {
197 auto node = *iter;
198 // Only processes graph kernel node
199 if (node == nullptr || !common::AnfAlgo::IsGraphKernel(node) || AnfAlgo::GetKernelMod(node) != nullptr) {
200 continue;
201 }
202 auto json_node = CollectNode(node);
203 nodes->push_back(json_node);
204 }
205 }
206
GetGraphKernelNodeName(const AnfNodePtr & node)207 std::string GetGraphKernelNodeName(const AnfNodePtr &node) {
208 auto cnode = node->cast<CNodePtr>();
209 MS_EXCEPTION_IF_NULL(cnode);
210 auto func_graph = GetCNodeFuncGraph(cnode);
211 if (func_graph->has_attr(kAttrNodeName)) {
212 return GetValue<std::string>(func_graph->get_attr(kAttrNodeName));
213 }
214 return std::string();
215 }
216
// Filters `nodes` down to those that truly need compiling. For every node it
// tries, in order: an already-set kernel mod, the in-memory kernel_pack_ cache
// (by kernel name, then by split-node name), a "_split" result json on disk
// (node was split by the kernel compiler), and finally the kernel json on disk
// (by kernel name, then by split-node name). Nodes with a usable cache get
// their kernel mod set here; everything else is returned for compilation.
std::vector<kernel::JsonNodePair> GraphKernelBuild::CollectNotCachedNodes(
  const std::vector<kernel::JsonNodePair> &nodes) {
  MS_EXCEPTION_IF_NULL(bin_map_);
  MS_EXCEPTION_IF_NULL(kernel_builder_);
  std::vector<kernel::JsonNodePair> res;
  for (const auto &[json_generator, node] : nodes) {
    if (node == nullptr) {
      continue;
    }
    // Skip node that already set kernel mod(created from compile cache).
    if (AnfAlgo::GetKernelMod(node) != nullptr) {
      MS_LOG(DEBUG) << "Skip node that already set kernel mod: " << json_generator.kernel_name();
      continue;
    }
    auto kernel_name = json_generator.kernel_name();
    // Skip node that already has cache.
    if (kernel_pack_.find(kernel_name) != kernel_pack_.end()) {
      kernel_builder_->SetKernelMod(kernel_pack_[kernel_name], json_generator, node);
      MS_LOG(DEBUG) << "Set cached kernel for node [" << node->fullname_with_scope() << "] with kernel name ["
                    << kernel_name << "]";
      continue;
    }

    // Nodes produced by the kernel-compiler split carry their kernel name on the
    // sub graph (kAttrNodeName); empty string when not a split node.
    std::string split_kernel_name = GetGraphKernelNodeName(node);
    // Check whether node is a split node and already has cache.
    if (kernel_pack_.find(split_kernel_name) != kernel_pack_.end()) {
      kernel_builder_->SetKernelMod(kernel_pack_[split_kernel_name], json_generator, node);
      MS_LOG(DEBUG) << "Set cached kernel for node [" << node->fullname_with_scope() << "] with kernel node name ["
                    << split_kernel_name << "]";
      continue;
    }

    std::string split_result_path = bin_map_->kernel_meta_path() + kernel_name + "_split" + kernel::kJsonSuffix;
    std::ifstream split_result_json(split_result_path);
    // Split json file exits, which means the node is split by the kernel compiler.
    if (split_result_json.is_open()) {
      // check split result
      MS_LOG(DEBUG) << "The node is split by the kernel compiler: " << kernel_name;
      split_result_json.close();
      // Handled by SplitNodesByKernelCompiler; nothing to compile here.
      continue;
    }

    std::string json_path = bin_map_->kernel_meta_path() + kernel_name + kernel::kJsonSuffix;
    std::ifstream kernel_json(json_path);
    // Json file not exits, which means the node does not have cache.
    if (!kernel_json.is_open()) {
      // Fall back to the split-node kernel name before giving up.
      std::string split_json_path = bin_map_->kernel_meta_path() + split_kernel_name + kernel::kJsonSuffix;
      std::ifstream split_kernel_json(split_json_path);
      if (!split_kernel_json.is_open()) {
        (void)res.emplace_back(json_generator, node);
        MS_LOG(DEBUG) << "The node does not have cache as the json [" << node->fullname_with_scope()
                      << "] with kernel name [" << kernel_name << "] is not found.";
        continue;
      } else {
        // The split-node json exists: use it for the cache lookup below.
        MS_LOG(DEBUG) << "The node has cache with split kernel as the json [" << node->fullname_with_scope()
                      << "] with kernel name [" << split_kernel_name << "] is found.";
        kernel_name = split_kernel_name;
        json_path = split_json_path;
        split_kernel_json.close();
      }
    } else {
      kernel_json.close();
    }

    // For GPU and CPU, we need to insert json path to bin_map_(KernelMeta) first, otherwise SearchKernelCache will
    // fail.
    (void)bin_map_->Insert(kernel_name, json_path);
    auto cached_kernel_pack = kernel_builder_->SearchKernelCache(kernel_name);
    // Node cache found.
    if (cached_kernel_pack != nullptr) {
      kernel_pack_[kernel_name] = cached_kernel_pack;
      kernel_builder_->SetKernelMod(cached_kernel_pack, json_generator, node);
      MS_LOG(DEBUG) << "Set cached kernel for node [" << node->fullname_with_scope() << "] with kernel name ["
                    << kernel_name << "]";
      continue;
    }
    // Node cache not found.
    (void)res.emplace_back(json_generator, node);
  }
  return res;
}
298
ParallelBuild(const std::vector<kernel::JsonNodePair> & nodes)299 void GraphKernelBuild::ParallelBuild(const std::vector<kernel::JsonNodePair> &nodes) {
300 std::vector<kernel::JsonNodePair> uniq_nodes;
301 std::unordered_set<std::string> kernel_names;
302 // GraphKernelBuildKernelBuilder::ParallelBuild can not process duplicate nodes, so we need to filter these nodes
303 // first.
304 for (const auto &[json_generator, node] : nodes) {
305 const auto &kernel_name = json_generator.kernel_name();
306 if (kernel_names.find(kernel_name) == kernel_names.end()) {
307 (void)kernel_names.insert(kernel_name);
308 (void)uniq_nodes.emplace_back(json_generator, node);
309 }
310 }
311 if (!uniq_nodes.empty()) {
312 MS_EXCEPTION_IF_NULL(kernel_builder_);
313 (void)kernel_builder_->ParallelBuild(uniq_nodes);
314 }
315 }
316
SplitNodes(const std::vector<kernel::JsonNodePair> & nodes)317 bool GraphKernelBuild::SplitNodes(const std::vector<kernel::JsonNodePair> &nodes) {
318 bool result = false;
319 std::unordered_set<std::string> kernel_names;
320 for (const auto &[json_generator, node] : nodes) {
321 const auto &kernel_name = json_generator.kernel_name();
322 // Print kernel name of nodes that compile failed.
323 if (kernel_names.find(kernel_name) == kernel_names.end()) {
324 (void)kernel_names.insert(kernel_name);
325 MS_LOG(WARNING) << "Nodes that with kernel name [" << kernel_name
326 << "] do not have compile cache after compiling and will be split.";
327 }
328 MS_EXCEPTION_IF_NULL(node);
329 auto cnode = node->cast<CNodePtr>();
330 MS_EXCEPTION_IF_NULL(cnode);
331 if (!splitter_.TrySplit(cnode)) {
332 // This means the compiled failed node also can not be split.
333 MS_LOG(EXCEPTION) << "Node [" << node->fullname_with_scope() << "] with kernel name [" << kernel_name
334 << "] compiled failed and can not be split.";
335 }
336 result = true;
337 }
338 return result;
339 }
340
SplitNodesByKernelCompiler(const std::vector<kernel::JsonNodePair> & nodes)341 bool GraphKernelBuild::SplitNodesByKernelCompiler(const std::vector<kernel::JsonNodePair> &nodes) {
342 MS_EXCEPTION_IF_NULL(bin_map_);
343 MS_EXCEPTION_IF_NULL(kernel_builder_);
344 bool result = false;
345 KernelCompilerGraphKernelSplitter compiler_splitter_;
346 for (const auto &[json_generator, node] : nodes) {
347 if (node == nullptr) {
348 continue;
349 }
350 const auto &kernel_name = json_generator.kernel_name();
351
352 std::string split_json_path = bin_map_->kernel_meta_path() + kernel_name + "_split" + kernel::kJsonSuffix;
353 std::ifstream kernel_split_json(split_json_path);
354 // Json file not exits, which means the node is not split by the kernel compiler.
355 if (!kernel_split_json.is_open()) {
356 continue;
357 }
358 nlohmann::json js;
359 kernel_split_json >> js;
360 kernel_split_json.close();
361
362 std::map<std::string, AnfNodePtr> address_node_map_ = json_generator.address_node_map();
363 compiler_splitter_.SetAddressNodeMap(address_node_map_);
364 compiler_splitter_.SetJson(js.dump());
365 auto cnode = node->cast<CNodePtr>();
366 MS_EXCEPTION_IF_NULL(cnode);
367 auto ori_sub_func_graph = GetCNodeFuncGraph(cnode);
368 ori_sub_func_graph->set_attr(kAttrNodeName, MakeValue(kernel_name));
369 if (!compiler_splitter_.TrySplit(cnode)) {
370 // This means the compiled failed node also can not be split.
371 MS_LOG(EXCEPTION) << "Node [" << node->fullname_with_scope() << "] with kernel name [" << kernel_name
372 << "] compiled failed and can not be split.";
373 }
374 result = true;
375 }
376 return result;
377 }
378
Run(const FuncGraphPtr & func_graph)379 bool GraphKernelBuild::Run(const FuncGraphPtr &func_graph) {
380 MS_EXCEPTION_IF_NULL(func_graph);
381 auto mng = func_graph->manager();
382 if (mng == nullptr) {
383 mng = Manage(func_graph, true);
384 func_graph->set_manager(mng);
385 }
386
387 Init();
388
389 bool changed = false;
390 bool need_traverse = true;
391 int iter = 1;
392 while (need_traverse) {
393 need_traverse = Process(func_graph, iter);
394 iter++;
395 changed = need_traverse || changed;
396 if (need_traverse) {
397 mng->RemoveRoots();
398 mng->KeepRoots({func_graph});
399 }
400 }
401
402 return changed;
403 }
404 } // namespace mindspore::graphkernel
405 #endif
406