/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/add_stitch_atomic_clean_gpu.h"

#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "include/backend/kernel_graph.h"
#include "include/common/utils/utils.h"
#include "kernel/framework_utils.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "utils/log_adapter.h"

namespace mindspore::graphkernel {
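// Rebuild the composite node's kernel build info so it accounts for the extra input added for the
// inplace-assignee (atomic clean) buffer; the original outputs are kept unchanged.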
void StitchAtomicCleanInserter::CorrectKernelBuildInfo(
  const AnfNodePtr &composite_node, const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &inplace_infos) {
  // Change kernel build info.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
  auto origin_processor = origin_kernel_build_info->processor();

  std::vector<std::string> new_inputs_format = origin_kernel_build_info->GetAllInputFormats();
  std::vector<TypeId> new_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
  std::vector<std::string> new_outputs_format;
  std::vector<TypeId> new_outputs_type;
  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
    new_outputs_format.push_back(origin_outputs_format[i]);
    new_outputs_type.push_back(origin_outputs_type[i]);
  }

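  // The extra input is the inplace-assignee address. VisitKernel traces through to the real kernel
  // output so its format and device type can be appended to the input lists.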
  auto kernel_with_index = common::AnfAlgo::VisitKernel(inplace_infos[0].second, 0);
  new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
  new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));

  auto new_selected_info = BuildSelectKernelBuildInfo(new_inputs_format, new_inputs_type, new_outputs_format,
                                                      new_outputs_type, origin_processor);
  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}

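// Route user_node's input through Depend(clean_node, composite_node): the user still consumes
// clean_node's value but is forced to execute after composite_node.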
void StitchAtomicCleanInserter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
                                          const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
                                          int index) const {
  // Create depend node to hold execution order.
  AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
  auto depend_cnode = main_graph->NewCNode(d_inputs);
  depend_cnode->set_abstract(clean_node->abstract());
  main_graph->AddNode(depend_cnode);

  auto user_cnode = user_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  user_cnode->set_input(IntToSize(index), depend_cnode);
}

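// Build Assign(new_parameter, op_node) inside the subgraph: the reduction result is written back to
// the buffer passed in through new_parameter. The stitch attribute moves from the reduce node to the
// Assign, which is re-marked as a "common" stitch node.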
CNodePtr StitchAtomicCleanInserter::CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
                                                     const InplaceAssignerInfo &info) const {
  // Add an Assign node.
  AnfNodePtr out_node = info.op_node;  // Use the result data itself.

  auto assign_node = CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, out_node}, sub_graph,
                                 {GetFormat(out_node), GetShape(out_node), GetType(out_node)});
  common::AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
  SetNodeAttrSafely(kAttrStitch, MakeValue("common"), assign_node);
  return assign_node;
}

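// Rewrite the composite (graph kernel) node for the stitch atomic-add case: feed the clean buffer in
// as a new input/parameter, assign the reduce result into it, redirect inner users to read the
// parameter, and rename the subgraph accordingly.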
void StitchAtomicCleanInserter::ProcessOriginCNode(
  const AnfNodePtr &composite_node,
  const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr) {
  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

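  // Only one inplace target is expected in the stitch scenario, so take the first entry.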
  auto [atomic_add_info, new_input] = info_and_inplace_assignee_addr[0];

  // Add the clean buffer as an extra input of the composite node.
  auto inputs = composite_node->cast<CNodePtr>()->inputs();
  inputs.push_back(new_input);
  composite_node->cast<CNodePtr>()->set_inputs(inputs);

  // Add a matching parameter to the subgraph.
  auto parameter = sub_graph->add_parameter();
  parameter->set_abstract(new_input->abstract());
  parameter->set_kernel_info(new_input->kernel_info_ptr());

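  // Write the atomic reduce result back into the new parameter (the clean buffer).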
  auto assign = CreateAssignNode(sub_graph, parameter, atomic_add_info);

  // Replace the atomic ReduceSum's users with the atomic clean output, and add a Depend after the
  // Assign to keep it from being eliminated.
  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes =
    FindInnerCNodeUsers(stitch_node_, atomic_add_info.op_node);
  bool connected = false;
  for (const auto &[user_node, index] : reduce_user_nodes) {
    auto user_cnode = user_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    user_cnode->set_input(IntToSize(index), parameter);
    if (!connected) {
      std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
      if (!user_user.empty()) {
        auto pair = user_user[0];
        AddDepend(sub_graph, user_cnode, assign, pair.first, pair.second);
      }
      connected = true;
    }
  }
  // Correct the build info once, after rewiring; calling it per user would append the new input repeatedly.
  CorrectKernelBuildInfo(composite_node, info_and_inplace_assignee_addr);

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

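// Collect the (user, input_index) pairs of `target` inside the subgraph of `inner_node`.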
std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInserter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
                                                                                       const CNodePtr &target) const {
  auto node = inner_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }
  std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes;
  auto users = mng_sub->node_users()[target];
  (void)std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes),
                       [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
  return inner_user_nodes;
}

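// Check whether a graph kernel node contains a ReduceSum marked with the "atomic" stitch attribute;
// if so, return the info needed to insert an atomic clean node for it.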
std::pair<bool, InplaceAssignerInfo> StitchAtomicCleanInserter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
  if (!common::AnfAlgo::IsGraphKernel(anf_node)) {
    return {false, InplaceAssignerInfo()};
  }
  auto node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  AnfNodePtrList kernel_nodes;
  kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
  for (auto &n : kernel_nodes) {
    if (common::AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
        common::AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" &&
        IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
      MS_LOG(INFO) << "Found stitch node with atomic add: " << n->fullname_with_scope();
      InplaceAssignerInfo info;
      info.op_node = n->cast<CNodePtr>();
      stitch_node_ = anf_node;
      return {true, info};
    }
  }
  return {false, InplaceAssignerInfo()};
}

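// Entry of the pass: scan the kernel graph in topological order and insert an atomic clean node for
// every graph kernel that stitches an atomic-add ReduceSum.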
bool StitchAtomicCleanInserter::Run(const FuncGraphPtr &func_graph) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  bool changed = false;
  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : topo_nodes) {
    // If the stitch attr marks an atomic-add ReduceSum, insert an atomic clean op for it.
    auto [is_stitch, atomic_add_info] = IsStitchWithAtomic(node);
    if (is_stitch) {
      InsertAtomicClean(kernel_graph, node, {atomic_add_info}, mng);
      changed = true;
    }
  }

  if (changed) {
    GkUtils::UpdateFuncGraphManager(mng, func_graph);
  }

  return changed;
}
}  // namespace mindspore::graphkernel