• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/convert_custom_for_ge.h"
17 #include "mindspore/core/ops/custom.h"
18 #include "utils/anf_utils.h"
19 #include "utils/file_utils.h"
20 #include "kernel/framework_utils.h"
21 #include "include/common/utils/anfalgo.h"
22 #include "kernel/graph_kernel/graph_kernel_json_generator.h"
23 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
24 
25 namespace mindspore::graphkernel {
CreateCustomOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)26 AnfNodePtr ConvertCustomForGE::CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
27   MS_EXCEPTION_IF_NULL(func_graph);
28   MS_EXCEPTION_IF_NULL(cnode);
29   auto op = std::make_shared<ops::Custom>();
30   op->set_type("GraphKernel");
31   auto custom_prim = op->GetPrim();
32   auto inputs = cnode->inputs();
33   inputs[0] = NewValueNode(custom_prim);
34   auto custom_cnode = func_graph->NewCNode(inputs);
35   auto json_name = node_json_name_[cnode->cast<AnfNodePtr>()];
36   auto input_num = AnfUtils::GetInputTensorNum(cnode);
37   auto output_num = AnfUtils::GetOutputTensorNum(cnode);
38   std::vector<std::string> input_names;
39   for (size_t i = 0; i < input_num; ++i) {
40     input_names.push_back("x" + std::to_string(i));
41   }
42   std::vector<std::string> output_names;
43   for (size_t i = 0; i < output_num; ++i) {
44     output_names.push_back("y" + std::to_string(i));
45   }
46 
47   std::ostringstream oss;
48   oss << "Fused_x" << input_num << "_y" << output_num;
49   std::string op_tye = oss.str();
50   custom_prim->set_attr("reg_op_name", MakeValue(op_tye));
51   custom_prim->set_attr("info_path", MakeValue(info_dir_ + "/" + json_name + ".info"));
52   custom_prim->set_attr("input_names", MakeValue(input_names));
53   custom_prim->set_attr("output_names", MakeValue(output_names));
54   custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
55   custom_cnode->set_abstract(cnode->abstract()->Clone());
56   return custom_cnode;
57 }
58 
CreateInfoDir()59 void ConvertCustomForGE::CreateInfoDir() {
60   static std::string rank_id = common::GetEnv("RANK_ID");
61   std::string dir;
62   if (rank_id.empty()) {
63     dir = "./akg_kernel_meta";
64   } else {
65     dir = "./rank_" + rank_id + "/akg_kernel_meta";
66   }
67   auto dir_path = FileUtils::CreateNotExistDirs(dir);
68   if (!dir_path.has_value()) {
69     MS_LOG(EXCEPTION) << "Failed to create directory: '" << dir << "'";
70   }
71   info_dir_ = dir_path.value();
72 }
73 
SaveNodesInfo(const AnfNodePtrList & nodes)74 void ConvertCustomForGE::SaveNodesInfo(const AnfNodePtrList &nodes) {
75   CreateInfoDir();
76   DumpOption option;
77   option.get_target_info = true;
78   std::set<std::string> unique_kernel_name;
79   for (const auto &node : nodes) {
80     graphkernel::GraphKernelJsonGenerator graph_kernel_json_generator(option);
81     FuncGraphPtr sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
82     MS_EXCEPTION_IF_NULL(sub_func_graph);
83     auto mng = sub_func_graph->manager();
84     if (mng == nullptr) {
85       mng = Manage(sub_func_graph, true);
86       sub_func_graph->set_manager(mng);
87     }
88     std::vector<AnfNodePtr> node_list, input_list, output_list;
89     GkUtils::GetValidKernelNodes(sub_func_graph, &node_list, &input_list, &output_list);
90     graph_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list);
91     auto kernel_name = graph_kernel_json_generator.kernel_name();
92     node_json_name_[node] = kernel_name;
93     if (!unique_kernel_name.insert(kernel_name).second) {
94       continue;
95     }
96     kernel::SaveJsonInfo(kernel_name, graph_kernel_json_generator.kernel_json_str(), info_dir_ + "/");
97   }
98 }
99 
Run(const FuncGraphPtr & func_graph)100 bool ConvertCustomForGE::Run(const FuncGraphPtr &func_graph) {
101   MS_EXCEPTION_IF_NULL(func_graph);
102   auto mng = func_graph->manager();
103   if (mng == nullptr) {
104     mng = Manage(func_graph, true);
105     func_graph->set_manager(mng);
106   }
107   auto node_list = GkUtils::GetGraphKernelNodes(func_graph);
108   // 1. generate node info file
109   SaveNodesInfo(node_list);
110   // 2. convert fused node to Custom op
111   for (const auto &node : node_list) {
112     auto cnode = node->cast<CNodePtr>();
113     auto custom_cnode = CreateCustomOp(func_graph, cnode);
114     if (custom_cnode == nullptr) {
115       MS_LOG(EXCEPTION) << "Create custom op failed for " << cnode->fullname_with_scope();
116     }
117     mng->Replace(node, custom_cnode);
118   }
119   return !node_list.empty();
120 }
121 }  // namespace mindspore::graphkernel
122