1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"

#include <algorithm>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "utils/contract.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
33
34 namespace mindspore {
35 namespace opt {
36 namespace {
// A (node, index) pair used by this pass: `node` together with the output
// index (for inplace nodes) or input index (for the aggregate node) that the
// inplace fusion acts on.
struct AnfNodeIndex {
  AnfNodeIndex() : node(nullptr), index(0) {}
  AnfNodeIndex(const AnfNodePtr &n, const int i) : node(n), index(i) {}
  AnfNodePtr node;
  uint32_t index;
};
43
// Ops whose output buffer can be shared in-place, mapped to the shared output
// index: Conv2DBackpropInput output 0, BatchNormGradWithAddAndActivation output 3.
std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0}, {kBatchNormGradWithAddAndActivation, 3}};

// Ops between the inplace nodes and the aggregate node; they get marked with
// the "skip" attribute once the fusion pattern matches.
std::set<string> kSkipOpNames = {
  kTensorAddOpName,
};

// Ops that consume the skip node's output, mapped to the input index at which
// the aggregated value is taken.
std::map<string, uint32_t> kAggregatesOpNames = {
    {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kBatchNormGradWithAddAndActivation, 0}};

// The fusion pattern pairs exactly two inplace nodes.
constexpr size_t inplace_node_size = 2;
56
57 template <typename T>
SetPrimAttr(AnfNodePtr inplace_node,const string & key,const T & value)58 void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
59 auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
60 MS_EXCEPTION_IF_NULL(primitive);
61 primitive->AddAttr(key, MakeValue(value));
62 }
63
64 // Check whether exist a route from src node to dst node.
ExistRoute(const CNodePtr & src,const CNodePtr & dst)65 bool ExistRoute(const CNodePtr &src, const CNodePtr &dst) {
66 MS_EXCEPTION_IF_NULL(src);
67 MS_EXCEPTION_IF_NULL(dst);
68
69 if (src == dst) {
70 return true;
71 }
72
73 size_t seen = NewSeenGeneration();
74 std::queue<CNodePtr> to_do;
75 to_do.push(dst);
76 while (!to_do.empty()) {
77 const auto ¤t_node = to_do.front();
78 size_t input_num = AnfAlgo::GetInputTensorNum(current_node);
79 for (size_t input_index = 0; input_index < input_num; ++input_index) {
80 const AnfNodePtr &input_node = AnfAlgo::GetInputNode(current_node, input_index);
81 const auto &cnode = input_node->cast<CNodePtr>();
82 if (cnode == nullptr) {
83 continue;
84 }
85 if (cnode->seen_ == seen) {
86 continue;
87 }
88 // Exist a route from src node to dst.
89 if (cnode == src) {
90 return true;
91 }
92 to_do.push(cnode);
93 cnode->seen_ = seen;
94 }
95 to_do.pop();
96 }
97 return false;
98 }
99
100 // Check whether exist a route from accumulate node to cover node.
ExistDependencyFromAcc2Cover(const std::vector<AnfNodeIndex> & inplace_node,size_t cover_index)101 bool ExistDependencyFromAcc2Cover(const std::vector<AnfNodeIndex> &inplace_node, size_t cover_index) {
102 if (inplace_node.size() != inplace_node_size) {
103 return false;
104 }
105
106 size_t acc_index = cover_index == 1 ? 0 : 1;
107 const CNodePtr &cover_node = inplace_node[cover_index].node->cast<CNodePtr>();
108 const CNodePtr &acc_node = inplace_node[acc_index].node->cast<CNodePtr>();
109 MS_EXCEPTION_IF_NULL(cover_node);
110 MS_EXCEPTION_IF_NULL(acc_node);
111 return ExistRoute(acc_node, cover_node);
112 }
113
GetCoverIndex(const std::vector<AnfNodeIndex> & inplace_node)114 std::pair<size_t, bool> GetCoverIndex(const std::vector<AnfNodeIndex> &inplace_node) {
115 if (inplace_node.size() != inplace_node_size) {
116 return {0, false};
117 }
118 auto first_node = inplace_node[0].node;
119 auto second_node = inplace_node[1].node;
120 if (AnfAlgo::GetCNodeName(first_node) != kConv2DBackpropInputOpName ||
121 AnfAlgo::GetCNodeName(second_node) != kConv2DBackpropInputOpName) {
122 return {0, false};
123 }
124
125 auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node);
126 MS_EXCEPTION_IF_NULL(first_node_prim);
127 auto first_node_channel = first_node_prim.get()->GetAttr("out_channel");
128 MS_EXCEPTION_IF_NULL(first_node_channel);
129 auto first_imm_ptr = first_node_channel->cast<Int64ImmPtr>();
130 MS_EXCEPTION_IF_NULL(first_imm_ptr);
131 size_t first_channel = first_imm_ptr->value();
132 auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node);
133 MS_EXCEPTION_IF_NULL(second_node_prim);
134 auto second_node_channel = second_node_prim.get()->GetAttr("out_channel");
135 MS_EXCEPTION_IF_NULL(second_node_channel);
136 auto second_imm_ptr = second_node_channel->cast<Int64ImmPtr>();
137 MS_EXCEPTION_IF_NULL(second_imm_ptr);
138 size_t second_channel = second_imm_ptr->value();
139 size_t cover_index = (first_channel >= second_channel) ? 0 : 1;
140 bool ret = ExistDependencyFromAcc2Cover(inplace_node, cover_index);
141 if (ret) {
142 return {0, false};
143 }
144 return {cover_index, true};
145 }
146
CopyKernelInfo(AnfNodePtr src,AnfNodePtr dst)147 void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) {
148 auto build_info = AnfAlgo::GetSelectKernelBuildInfo(src);
149 AnfAlgo::SetSelectKernelBuildInfo(build_info, dst.get());
150 size_t output_num = AnfAlgo::GetOutputTensorNum(src);
151 std::vector<TypeId> types;
152 std::vector<std::vector<size_t>> shapes;
153 for (size_t i = 0; i < output_num; i++) {
154 types.emplace_back(AnfAlgo::GetOutputInferDataType(src, i));
155 shapes.emplace_back(AnfAlgo::GetOutputInferShape(src, i));
156 }
157 AnfAlgo::SetOutputInferTypeAndShape(types, shapes, dst.get());
158 }
159
CheckInplaceNodeInputs(std::vector<AnfNodeIndex> * inplace_node,size_t cover_index,const FuncGraphPtr & graph)160 void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cover_index, const FuncGraphPtr &graph) {
161 // If two inplace nodes have same input, will be have loop after insert depend:
162 // A A Cover <----+
163 // / \ \ / |
164 // B \ --> Depend -------> B
165 // / \ |
166 // Cover Acc Acc
167 // so copy a new input for one of inplace node like this
168 // A A' A A'
169 // | | | |
170 // B | --> B Depend <-+
171 // | | | | |
172 // Cover Acc | Acc |
173 // Cover---------------+
174 MS_EXCEPTION_IF_NULL(inplace_node);
175 MS_EXCEPTION_IF_NULL(graph);
176 size_t acc_index = cover_index == 1 ? 0 : 1;
177 const CNodePtr &cover_node = inplace_node->at(cover_index).node->cast<CNodePtr>();
178 const CNodePtr &acc_node = inplace_node->at(acc_index).node->cast<CNodePtr>();
179 MS_EXCEPTION_IF_NULL(cover_node);
180 MS_EXCEPTION_IF_NULL(acc_node);
181 const auto &acc_input = acc_node->input(1)->cast<CNodePtr>();
182 if (acc_input == nullptr) {
183 return;
184 }
185 bool ret = ExistRoute(acc_input, cover_node);
186 if (ret) {
187 auto new_input = graph->NewCNode(acc_input->inputs());
188 MS_EXCEPTION_IF_NULL(new_input);
189 new_input->set_abstract(acc_input->abstract());
190 CopyKernelInfo(acc_input, new_input);
191 auto new_inplace_node = graph->NewCNode({acc_node->input(0), new_input, acc_node->input(2)});
192 MS_EXCEPTION_IF_NULL(new_inplace_node);
193 new_inplace_node->set_abstract(acc_node->abstract());
194 CopyKernelInfo(acc_node, new_inplace_node);
195 auto manager = graph->manager();
196 MS_EXCEPTION_IF_NULL(manager);
197 manager->Replace(acc_node, new_inplace_node);
198 (*inplace_node)[acc_index].node = new_inplace_node;
199 }
200 }
201
SetNodeAttr(AnfNodeIndex aggregate_node,AnfNodePtr skip_node,std::vector<AnfNodeIndex> * inplace_node,const FuncGraphPtr & graph)202 void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node,
203 const FuncGraphPtr &graph) {
204 MS_EXCEPTION_IF_NULL(skip_node);
205 MS_EXCEPTION_IF_NULL(inplace_node);
206 MS_EXCEPTION_IF_NULL(graph);
207
208 SetPrimAttr(aggregate_node.node, "aggregate", true);
209 SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
210 SetPrimAttr(skip_node, "skip", true);
211
212 static uint32_t group = 0;
213 auto [cover_index, order_required] = GetCoverIndex(*inplace_node);
214 CheckInplaceNodeInputs(inplace_node, cover_index, graph);
215
216 for (size_t i = 0; i < inplace_node->size(); i++) {
217 auto algo = (i == cover_index) ? "cover" : "accumulation";
218 auto node = (*inplace_node)[i].node;
219 MS_EXCEPTION_IF_NULL(node);
220 SetPrimAttr(node, "inplace_algo", algo);
221 SetPrimAttr(node, "inplace_group", group);
222 SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index);
223 // for Conv2DBackpropInputOp, need insert depend node to keep order, set the larger channel to cover
224 if (order_required && i != cover_index) {
225 auto acc_node = node;
226 auto cover_node = (*inplace_node)[cover_index].node;
227 auto cnode = acc_node->cast<CNodePtr>();
228 MS_EXCEPTION_IF_NULL(cnode);
229 auto acc_node_input = cnode->input(1);
230 std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
231 acc_node_input, cover_node};
232 auto depend_node = graph->NewCNode(inputs);
233 MS_EXCEPTION_IF_NULL(depend_node);
234 depend_node->set_abstract(acc_node_input->abstract());
235 auto manager = graph->manager();
236 MS_EXCEPTION_IF_NULL(manager);
237 manager->Replace(acc_node_input, depend_node);
238 }
239 }
240 group++;
241 }
242
PatternMatch(const FuncGraphPtr & graph,const AnfNodePtr & node,AnfNodeIndex * aggregate,AnfNodePtr * skip_node,std::vector<AnfNodeIndex> * inplace)243 bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
244 std::vector<AnfNodeIndex> *inplace) {
245 MS_EXCEPTION_IF_NULL(graph);
246 MS_EXCEPTION_IF_NULL(node);
247 MS_EXCEPTION_IF_NULL(inplace);
248 MS_EXCEPTION_IF_NULL(skip_node);
249 MS_EXCEPTION_IF_NULL(aggregate);
250 if (!node->isa<CNode>()) {
251 return false;
252 }
253 auto aggregate_iter = kAggregatesOpNames.find(AnfAlgo::GetCNodeName(node));
254 if (aggregate_iter == kAggregatesOpNames.end()) {
255 return false;
256 }
257 aggregate->node = node;
258 aggregate->index = aggregate_iter->second;
259
260 *skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), aggregate_iter->second);
261 if (*skip_node == nullptr || !(*skip_node)->isa<CNode>() ||
262 kSkipOpNames.count(AnfAlgo::GetCNodeName(*skip_node)) == 0 ||
263 GetRealNodeUsedList(graph, *skip_node)->size() >= 2) {
264 return false;
265 }
266
267 auto cnode = (*skip_node)->cast<CNodePtr>();
268 MS_EXCEPTION_IF_NULL(cnode);
269 size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
270 for (size_t i = 0; i < input_num; i++) {
271 auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
272 if (!inplace_node->isa<CNode>()) {
273 return false;
274 }
275 // Check Inplace nodes have no user except TensorAdd nodes
276 if (GetRealNodeUsedList(graph, inplace_node)->size() >= 2) {
277 return false;
278 }
279
280 // skip TupleGetItem node
281 if (AnfAlgo::GetCNodeName(inplace_node) == prim::kPrimTupleGetItem->name()) {
282 inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(inplace_node), 0);
283 }
284
285 auto inplace_iter = kInplaceOpNames.find(AnfAlgo::GetCNodeName(inplace_node));
286 if (inplace_iter == kInplaceOpNames.end()) {
287 return false;
288 }
289
290 inplace->push_back(AnfNodeIndex(inplace_node, inplace_iter->second));
291 }
292
293 return true;
294 }
295
TopoIndex(const std::vector<AnfNodePtr> & node_list)296 std::map<AnfNodePtr, int> TopoIndex(const std::vector<AnfNodePtr> &node_list) {
297 std::map<AnfNodePtr, int> topo_index;
298 for (size_t i = 0; i < node_list.size(); i++) {
299 topo_index.insert(make_pair(node_list[i], i));
300 }
301 return topo_index;
302 }
303 } // namespace
304
Run(const FuncGraphPtr & graph)305 bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
306 MS_EXCEPTION_IF_NULL(graph);
307 std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
308 auto topo_index = TopoIndex(node_list);
309
310 for (auto node : node_list) {
311 AnfNodeIndex aggregate_node;
312 AnfNodePtr skip_node;
313 std::vector<AnfNodeIndex> inplace_node;
314 // 1. Pattern Match.
315 if (!PatternMatch(graph, node, &aggregate_node, &skip_node, &inplace_node)) {
316 continue;
317 }
318
319 // 2. Keep the original topological order in case the dependence between inplace nodes
320 std::sort(inplace_node.begin(), inplace_node.end(), [&topo_index](const AnfNodeIndex &n1, const AnfNodeIndex &n2) {
321 auto iter1 = topo_index.find(n1.node);
322 auto iter2 = topo_index.find(n2.node);
323 if (iter1 == topo_index.end() || iter2 == topo_index.end()) {
324 MS_LOG(EXCEPTION) << ": Node not existed in topo order. node1: " << n1.node->DebugString()
325 << ", node2: " << n2.node->DebugString();
326 }
327
328 if (iter1->second < iter2->second) {
329 return true;
330 }
331 return false;
332 });
333 MS_LOG(INFO) << "[inplace optimizer] aggregate node: " << aggregate_node.index << ", "
334 << aggregate_node.node->DebugString() << "; skip node: " << skip_node->DebugString() << std::endl
335 << "; inplace node 0: " << inplace_node[0].index << ", " << inplace_node[0].node->DebugString()
336 << std::endl
337 << "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
338 << std::endl;
339 // 2. Set Node attr
340 SetNodeAttr(aggregate_node, skip_node, &inplace_node, graph);
341 }
342
343 return true;
344 }
345 } // namespace opt
346 } // namespace mindspore
347