/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "base/core_ops.h"
#include "ir/tensor.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"

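// StitchAtomicCleanInsertter: for graph-kernel composite nodes whose inner ReduceSum is
// tagged with the stitch="atomic" attribute, this pass inserts an atomic clean op (via
// InsertAtomicClean) and rewires the composite node so that the reduction accumulates into
// a separately provided buffer: the composite node gets an extra input, its sub graph gets
// a matching parameter, and an InplaceAssign writes the ReduceSum result into that parameter.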
namespace mindspore {
namespace opt {
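// Build an InplaceAssign CNode inside the sub graph that assigns the atomic ReduceSum
// result to the new parameter. The node reuses the ReduceSum output as its own output and
// is marked "fake_output"; the "stitch" attribute is removed from the ReduceSum and set to
// "common" on the new node.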
CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr &sub_graph,
                                                             const AnfNodePtr &new_parameter) const {
  // Add InplaceAssign.
  AnfNodePtr out_node = atomic_add_node_;  // Use the result data itself, and set attr "fake_output" true.
  auto inplace_assign_node =
    CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, atomic_add_node_, out_node}, sub_graph,
                {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
  SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node);
  AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_);
  SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node);
  return inplace_assign_node;
}

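// Rewire the composite node that contains the atomic ReduceSum: append new_input (the
// cleaned accumulation buffer) to its inputs, mirror it as a parameter of the sub graph,
// redirect the ReduceSum's inner users to that parameter, and insert an InplaceAssign plus
// a Depend so the assignment is not eliminated. Finally, the graph-kernel name is refreshed
// to mark the graph as an atomic add graph.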
void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // Add the new input to the composite node.
  auto inputs = composite_node->cast<CNodePtr>()->inputs();
  inputs.push_back(new_input);
  composite_node->cast<CNodePtr>()->set_inputs(inputs);

  // Add a corresponding parameter to the sub graph.
  auto parameter = sub_graph->add_parameter();
  parameter->set_abstract(new_input->abstract());
  parameter->set_kernel_info(new_input->kernel_info_ptr());

  auto inplace_assign = CreateInplaceAssignNode(sub_graph, parameter);

  // Redirect users of the atomic ReduceSum to the atomic clean output, and add a Depend op
  // after the InplaceAssign to avoid its elimination.
  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_);
  bool connected = false;
  for (const auto &[user_node, index] : reduce_user_nodes) {
    auto user_cnode = user_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    user_cnode->set_input(static_cast<size_t>(index), parameter);
    if (!connected) {
      std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
      if (!user_user.empty()) {
        auto pair = user_user[0];
        AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second);
      }
      connected = true;
    }
    CorrectKernelBuildInfo(composite_node, new_input, false);
  }

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

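// Collect the users of `target` inside the sub graph attached to `inner_node` (a
// graph-kernel composite node), returned as (user node, input index) pairs.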
std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
                                                                                        const CNodePtr &target) const {
  auto node = inner_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }
  std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes;
  auto users = mng_sub->node_users()[target];
  std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes),
                 [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
  return inner_user_nodes;
}

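// Check whether a graph-kernel composite node contains an inner ReduceSum tagged with the
// stitch="atomic" attribute. If so, cache the ReduceSum and the composite node for the
// later rewrite and return true.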
bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
  if (!AnfAlgo::IsGraphKernel(anf_node)) return false;
  auto node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  AnfNodePtrList kernel_nodes;
  kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
  for (auto &n : kernel_nodes) {
    if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
        AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
      MS_LOG(INFO) << "Found stitch node with atomic add.";
      atomic_add_node_ = n->cast<CNodePtr>();
      stitch_node_ = anf_node;
      return true;
    }
  }
  return false;
}

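// Entry point of the pass: walk the kernel graph in topological order and, for every
// graph-kernel node whose inner ReduceSum is marked stitch="atomic", insert the atomic
// clean op and rewire the node. Returns true if the graph was modified.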
bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  bool changed = false;
  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : topo_nodes) {
    // If the stitch attribute exists, add an atomic clean op according to that attribute.
    if (IsStitchWithAtomic(node)) {
      InsertAtomicClean(kernel_graph, node, mng);
      changed = true;
    }
  }

  if (changed) {
    UpdateMng(mng, func_graph);
  }

  return changed;
}
}  // namespace opt
}  // namespace mindspore