1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/convert_custom_for_ge.h"
17 #include "mindspore/core/ops/custom.h"
18 #include "utils/anf_utils.h"
19 #include "utils/file_utils.h"
20 #include "kernel/framework_utils.h"
21 #include "include/common/utils/anfalgo.h"
22 #include "kernel/graph_kernel/graph_kernel_json_generator.h"
23 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
24
25 namespace mindspore::graphkernel {
CreateCustomOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)26 AnfNodePtr ConvertCustomForGE::CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
27 MS_EXCEPTION_IF_NULL(func_graph);
28 MS_EXCEPTION_IF_NULL(cnode);
29 auto op = std::make_shared<ops::Custom>();
30 op->set_type("GraphKernel");
31 auto custom_prim = op->GetPrim();
32 auto inputs = cnode->inputs();
33 inputs[0] = NewValueNode(custom_prim);
34 auto custom_cnode = func_graph->NewCNode(inputs);
35 auto json_name = node_json_name_[cnode->cast<AnfNodePtr>()];
36 auto input_num = AnfUtils::GetInputTensorNum(cnode);
37 auto output_num = AnfUtils::GetOutputTensorNum(cnode);
38 std::vector<std::string> input_names;
39 for (size_t i = 0; i < input_num; ++i) {
40 input_names.push_back("x" + std::to_string(i));
41 }
42 std::vector<std::string> output_names;
43 for (size_t i = 0; i < output_num; ++i) {
44 output_names.push_back("y" + std::to_string(i));
45 }
46
47 std::ostringstream oss;
48 oss << "Fused_x" << input_num << "_y" << output_num;
49 std::string op_tye = oss.str();
50 custom_prim->set_attr("reg_op_name", MakeValue(op_tye));
51 custom_prim->set_attr("info_path", MakeValue(info_dir_ + "/" + json_name + ".info"));
52 custom_prim->set_attr("input_names", MakeValue(input_names));
53 custom_prim->set_attr("output_names", MakeValue(output_names));
54 custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
55 custom_cnode->set_abstract(cnode->abstract()->Clone());
56 return custom_cnode;
57 }
58
CreateInfoDir()59 void ConvertCustomForGE::CreateInfoDir() {
60 static std::string rank_id = common::GetEnv("RANK_ID");
61 std::string dir;
62 if (rank_id.empty()) {
63 dir = "./akg_kernel_meta";
64 } else {
65 dir = "./rank_" + rank_id + "/akg_kernel_meta";
66 }
67 auto dir_path = FileUtils::CreateNotExistDirs(dir);
68 if (!dir_path.has_value()) {
69 MS_LOG(EXCEPTION) << "Failed to create directory: '" << dir << "'";
70 }
71 info_dir_ = dir_path.value();
72 }
73
SaveNodesInfo(const AnfNodePtrList & nodes)74 void ConvertCustomForGE::SaveNodesInfo(const AnfNodePtrList &nodes) {
75 CreateInfoDir();
76 DumpOption option;
77 option.get_target_info = true;
78 std::set<std::string> unique_kernel_name;
79 for (const auto &node : nodes) {
80 graphkernel::GraphKernelJsonGenerator graph_kernel_json_generator(option);
81 FuncGraphPtr sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
82 MS_EXCEPTION_IF_NULL(sub_func_graph);
83 auto mng = sub_func_graph->manager();
84 if (mng == nullptr) {
85 mng = Manage(sub_func_graph, true);
86 sub_func_graph->set_manager(mng);
87 }
88 std::vector<AnfNodePtr> node_list, input_list, output_list;
89 GkUtils::GetValidKernelNodes(sub_func_graph, &node_list, &input_list, &output_list);
90 graph_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list);
91 auto kernel_name = graph_kernel_json_generator.kernel_name();
92 node_json_name_[node] = kernel_name;
93 if (!unique_kernel_name.insert(kernel_name).second) {
94 continue;
95 }
96 kernel::SaveJsonInfo(kernel_name, graph_kernel_json_generator.kernel_json_str(), info_dir_ + "/");
97 }
98 }
99
Run(const FuncGraphPtr & func_graph)100 bool ConvertCustomForGE::Run(const FuncGraphPtr &func_graph) {
101 MS_EXCEPTION_IF_NULL(func_graph);
102 auto mng = func_graph->manager();
103 if (mng == nullptr) {
104 mng = Manage(func_graph, true);
105 func_graph->set_manager(mng);
106 }
107 auto node_list = GkUtils::GetGraphKernelNodes(func_graph);
108 // 1. generate node info file
109 SaveNodesInfo(node_list);
110 // 2. convert fused node to Custom op
111 for (const auto &node : node_list) {
112 auto cnode = node->cast<CNodePtr>();
113 auto custom_cnode = CreateCustomOp(func_graph, cnode);
114 if (custom_cnode == nullptr) {
115 MS_LOG(EXCEPTION) << "Create custom op failed for " << cnode->fullname_with_scope();
116 }
117 mng->Replace(node, custom_cnode);
118 }
119 return !node_list.empty();
120 }
121 } // namespace mindspore::graphkernel
122