/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "base/core_ops.h"
#include "ir/tensor.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"

namespace mindspore {
namespace opt {
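// Build an InplaceAssign node in the sub graph that writes the atomic-add (ReduceSum) result into the new
// parameter; the ReduceSum result itself is also forwarded as the node's output, marked as "fake_output".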
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
                                                             const AnfNodePtr &new_parameter) const {
  // Add the InplaceAssign node.
  AnfNodePtr out_node = atomic_add_node_;  // Use the result data itself as output, and set attr "fake_output" true.
  auto inplace_assign_node =
    CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
                {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
  SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node);
  AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_);
  SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node);
  return inplace_assign_node;
}

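// Rewrite the composite (graph kernel) node to consume the atomic clean buffer: append it as a new input and
// sub graph parameter, redirect the users of the atomic-add ReduceSum to that parameter, and create the
// InplaceAssign (kept alive by a Depend) so the assignment is preserved and ordered before its consumers.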
void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // Add the new input to the composite node.
  auto inputs = composite_node->cast<CNodePtr>()->inputs();
  inputs.push_back(new_input);
  composite_node->cast<CNodePtr>()->set_inputs(inputs);

  // Add the corresponding parameter to the sub graph.
  auto parameter = sub_graph->add_parameter();
  parameter->set_abstract(new_input->abstract());
  parameter->set_kernel_info(new_input->kernel_info_ptr());

  auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter);

  // Replace the atomic ReduceSum's users with the atomic clean output, and add a Depend after the InplaceAssign to
  // avoid its elimination.
  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_);
  // Only one Depend is needed, so it is added for the first user only.
  bool connected = false;
  for (const auto &[user_node, index] : reduce_user_nodes) {
    auto user_cnode = user_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    user_cnode->set_input(static_cast<size_t>(index), parameter);
    if (!connected) {
      std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
      if (!user_user.empty()) {
        auto pair = user_user[0];
        AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second);
      }
      connected = true;
    }
    CorrectKernelBuildInfo(composite_node, new_input, false);
  }

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

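// Collect the users of `target` inside the sub graph of `inner_node`, paired with the input index they use.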
std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
                                                                                        const CNodePtr &target) const {
  auto node = inner_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }
  std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes;
  auto users = mng_sub->node_users()[target];
  std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes),
                 [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
  return inner_user_nodes;
}

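// Return true if `anf_node` is a graph kernel whose sub graph contains a ReduceSum marked with stitch="atomic";
// if so, record the ReduceSum and the enclosing graph kernel node for later processing.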
bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
  if (!AnfAlgo::IsGraphKernel(anf_node)) return false;
  auto node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  AnfNodePtrList kernel_nodes;
  kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
  for (auto &n : kernel_nodes) {
    if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
        AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
      MS_LOG(INFO) << "Found stitch node with atomic add: " << n->fullname_with_scope();
      atomic_add_node_ = n->cast<CNodePtr>();
      stitch_node_ = anf_node;
      return true;
    }
  }
  return false;
}

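// Entry of the pass: traverse the kernel graph in topological order and insert an atomic clean node for every
// graph kernel that stitches an atomic-add ReduceSum.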
bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  bool changed = false;
  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : topo_nodes) {
    // If the stitch attr marks an atomic-add node, insert an atomic clean op for it.
    if (IsStitchWithAtomic(node)) {
      InsertAtomicClean(kernel_graph, node, mng);
      changed = true;
    }
  }

  if (changed) {
    UpdateMng(mng, func_graph);
  }

  return changed;
}
}  // namespace opt
}  // namespace mindspore