/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/add_stitch_atomic_clean_gpu.h"

#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "include/backend/kernel_graph.h"
#include "include/common/utils/utils.h"
#include "kernel/framework_utils.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "utils/log_adapter.h"

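// This pass locates a GraphKernel composite node whose sub graph contains a ReduceSum
// marked with the "atomic" stitch attribute, and rewrites the sub graph so the
// reduction accumulates into an externally cleaned buffer through an inplace Assign.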
namespace mindspore::graphkernel {
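// Rebuild the composite node's kernel build info: the original inputs and outputs are
// kept, and the atomic-clean buffer is appended as an extra input with the format and
// device type of the kernel that produces it.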
void StitchAtomicCleanInserter::CorrectKernelBuildInfo(
  const AnfNodePtr &composite_node, const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &inplace_infos) {
  // Change kernel build info.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(composite_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
  auto origin_processor = origin_kernel_build_info->processor();

  std::vector<std::string> new_inputs_format = origin_kernel_build_info->GetAllInputFormats();
  std::vector<TypeId> new_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
  std::vector<std::string> new_outputs_format;
  std::vector<TypeId> new_outputs_type;
  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
    new_outputs_format.push_back(origin_outputs_format[i]);
    new_outputs_type.push_back(origin_outputs_type[i]);
  }

  auto kernel_with_index = common::AnfAlgo::VisitKernel(inplace_infos[0].second, 0);
  new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
  new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));

  auto new_selected_info = BuildSelectKernelBuildInfo(new_inputs_format, new_inputs_type, new_outputs_format,
                                                      new_outputs_type, origin_processor);
  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}

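// Insert Depend(clean_node, composite_node) into the graph and feed it to user_node's
// input `index`, so user_node still consumes clean_node's value but is ordered after
// composite_node.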
void StitchAtomicCleanInserter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
                                          const AnfNodePtr &composite_node, const AnfNodePtr &user_node,
                                          int index) const {
  // Create depend node to hold execution order.
  AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
  auto depend_cnode = main_graph->NewCNode(d_inputs);
  depend_cnode->set_abstract(clean_node->abstract());
  main_graph->AddNode(depend_cnode);

  auto user_cnode = user_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  user_cnode->set_input(IntToSize(index), depend_cnode);
}

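// Create Assign(new_parameter, reduce_output) inside the sub graph so the reduction
// result is written back to the cleaned buffer. The stitch attribute is erased from
// the reduce node and a "common" stitch attribute is set on the Assign.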
CNodePtr StitchAtomicCleanInserter::CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
                                                     const InplaceAssignerInfo &info) const {
  // Add an Assign node that stores the reduction result into the new parameter.
  AnfNodePtr out_node = info.op_node;  // Use result data itself

  auto assign_node = CreateCNode({NewValueNode(prim::kPrimAssign), new_parameter, out_node}, sub_graph,
                                 {GetFormat(out_node), GetShape(out_node), GetType(out_node)});
  common::AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
  SetNodeAttrSafely(kAttrStitch, MakeValue("common"), assign_node);
  return assign_node;
}

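// Wire the atomic-clean buffer into the composite node: append it as a call input and
// as a sub graph parameter, create the Assign that stores the reduction result into it,
// redirect the reduce node's inner users to the parameter, and rename the graph kernel.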
void StitchAtomicCleanInserter::ProcessOriginCNode(
  const AnfNodePtr &composite_node,
  const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr) {
  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  auto [atomic_add_info, new_input] = info_and_inplace_assignee_addr[0];

  // add input
  auto inputs = composite_node->cast<CNodePtr>()->inputs();
  inputs.push_back(new_input);
  composite_node->cast<CNodePtr>()->set_inputs(inputs);

  // add parameter
  auto parameter = sub_graph->add_parameter();
  parameter->set_abstract(new_input->abstract());
  parameter->set_kernel_info(new_input->kernel_info_ptr());

  auto assign = CreateAssignNode(sub_graph, parameter, atomic_add_info);

  // Replace atomic ReduceSum's user with atomic clean output, and add depend op after assign to avoid
  // elimination.
  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes =
    FindInnerCNodeUsers(stitch_node_, atomic_add_info.op_node);
  bool connected = false;
  for (const auto &[user_node, index] : reduce_user_nodes) {
    auto user_cnode = user_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    user_cnode->set_input(IntToSize(index), parameter);
    if (!connected) {
      std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
      if (!user_user.empty()) {
        auto pair = user_user[0];
        AddDepend(sub_graph, user_cnode, assign, pair.first, pair.second);
      }
      connected = true;
    }
    CorrectKernelBuildInfo(composite_node, info_and_inplace_assignee_addr);
  }

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

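// Return the users of `target` inside the sub graph of `inner_node`, paired with the
// input index at which each user consumes it.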
std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInserter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
                                                                                       const CNodePtr &target) const {
  auto node = inner_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }
  std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes;
  auto users = mng_sub->node_users()[target];
  (void)std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes),
                       [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
  return inner_user_nodes;
}

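// A graph kernel node is a stitch-with-atomic candidate when its sub graph contains a
// ReduceSum whose stitch attribute is "atomic". On a match, the reduce node is recorded
// and returned as the inplace-assigner info.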
std::pair<bool, InplaceAssignerInfo> StitchAtomicCleanInserter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
  if (!common::AnfAlgo::IsGraphKernel(anf_node)) {
    return {false, InplaceAssignerInfo()};
  }
  auto node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  AnfNodePtrList kernel_nodes;
  kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
  for (auto &n : kernel_nodes) {
    if (common::AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
        common::AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" &&
        IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
      MS_LOG(INFO) << "Found stitch node with atomic add: " << n->fullname_with_scope();
      InplaceAssignerInfo info;
      info.op_node = n->cast<CNodePtr>();
      stitch_node_ = anf_node;
      return {true, info};
    }
  }
  return {false, InplaceAssignerInfo()};
}

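// Entry point of the pass: scan the kernel graph in topological order and insert an
// atomic clean node for every stitch-with-atomic graph kernel that is found.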
bool StitchAtomicCleanInserter::Run(const FuncGraphPtr &func_graph) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  bool changed = false;
  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : topo_nodes) {
    // If the node is a graph kernel whose stitch attribute marks an atomic reduce, insert an atomic clean op for it.
    auto [is_stitch, atomic_add_info] = IsStitchWithAtomic(node);
    if (is_stitch) {
      InsertAtomicClean(kernel_graph, node, {atomic_add_info}, mng);
      changed = true;
    }
  }

  if (changed) {
    GkUtils::UpdateFuncGraphManager(mng, func_graph);
  }

  return changed;
}
}  // namespace mindspore::graphkernel