• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
18 
#include <algorithm>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>
26 
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "ir/primitive.h"
29 #include "utils/utils.h"
30 #include "utils/contract.h"
31 #include "backend/optimizer/common/helper.h"
32 #include "runtime/device/gpu/kernel_info_setter.h"
33 
34 namespace mindspore {
35 namespace opt {
36 namespace {
37 struct AnfNodeIndex {
AnfNodeIndexmindspore::opt::__anon7643c61f0111::AnfNodeIndex38   AnfNodeIndex() : node(nullptr), index(0) {}
AnfNodeIndexmindspore::opt::__anon7643c61f0111::AnfNodeIndex39   AnfNodeIndex(const AnfNodePtr &n, const int i) : node(n), index(i) {}
40   AnfNodePtr node;
41   uint32_t index;
42 };
43 
44 // opname, output idx
45 std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0}, {kBatchNormGradWithAddAndActivation, 3}};
46 
47 std::set<string> kSkipOpNames = {
48   kTensorAddOpName,
49 };
50 
51 // opname, input idx
52 std::map<string, uint32_t> kAggregatesOpNames = {
53   {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kBatchNormGradWithAddAndActivation, 0}};
54 
55 constexpr size_t inplace_node_size = 2;
56 
57 template <typename T>
SetPrimAttr(AnfNodePtr inplace_node,const string & key,const T & value)58 void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
59   auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
60   MS_EXCEPTION_IF_NULL(primitive);
61   primitive->AddAttr(key, MakeValue(value));
62 }
63 
64 // Check whether exist a route from src node to dst node.
ExistRoute(const CNodePtr & src,const CNodePtr & dst)65 bool ExistRoute(const CNodePtr &src, const CNodePtr &dst) {
66   MS_EXCEPTION_IF_NULL(src);
67   MS_EXCEPTION_IF_NULL(dst);
68 
69   if (src == dst) {
70     return true;
71   }
72 
73   size_t seen = NewSeenGeneration();
74   std::queue<CNodePtr> to_do;
75   to_do.push(dst);
76   while (!to_do.empty()) {
77     const auto &current_node = to_do.front();
78     size_t input_num = AnfAlgo::GetInputTensorNum(current_node);
79     for (size_t input_index = 0; input_index < input_num; ++input_index) {
80       const AnfNodePtr &input_node = AnfAlgo::GetInputNode(current_node, input_index);
81       const auto &cnode = input_node->cast<CNodePtr>();
82       if (cnode == nullptr) {
83         continue;
84       }
85       if (cnode->seen_ == seen) {
86         continue;
87       }
88       // Exist a route from src node to dst.
89       if (cnode == src) {
90         return true;
91       }
92       to_do.push(cnode);
93       cnode->seen_ = seen;
94     }
95     to_do.pop();
96   }
97   return false;
98 }
99 
100 // Check whether exist a route from accumulate node to cover node.
ExistDependencyFromAcc2Cover(const std::vector<AnfNodeIndex> & inplace_node,size_t cover_index)101 bool ExistDependencyFromAcc2Cover(const std::vector<AnfNodeIndex> &inplace_node, size_t cover_index) {
102   if (inplace_node.size() != inplace_node_size) {
103     return false;
104   }
105 
106   size_t acc_index = cover_index == 1 ? 0 : 1;
107   const CNodePtr &cover_node = inplace_node[cover_index].node->cast<CNodePtr>();
108   const CNodePtr &acc_node = inplace_node[acc_index].node->cast<CNodePtr>();
109   MS_EXCEPTION_IF_NULL(cover_node);
110   MS_EXCEPTION_IF_NULL(acc_node);
111   return ExistRoute(acc_node, cover_node);
112 }
113 
GetCoverIndex(const std::vector<AnfNodeIndex> & inplace_node)114 std::pair<size_t, bool> GetCoverIndex(const std::vector<AnfNodeIndex> &inplace_node) {
115   if (inplace_node.size() != inplace_node_size) {
116     return {0, false};
117   }
118   auto first_node = inplace_node[0].node;
119   auto second_node = inplace_node[1].node;
120   if (AnfAlgo::GetCNodeName(first_node) != kConv2DBackpropInputOpName ||
121       AnfAlgo::GetCNodeName(second_node) != kConv2DBackpropInputOpName) {
122     return {0, false};
123   }
124 
125   auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node);
126   MS_EXCEPTION_IF_NULL(first_node_prim);
127   auto first_node_channel = first_node_prim.get()->GetAttr("out_channel");
128   MS_EXCEPTION_IF_NULL(first_node_channel);
129   auto first_imm_ptr = first_node_channel->cast<Int64ImmPtr>();
130   MS_EXCEPTION_IF_NULL(first_imm_ptr);
131   size_t first_channel = first_imm_ptr->value();
132   auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node);
133   MS_EXCEPTION_IF_NULL(second_node_prim);
134   auto second_node_channel = second_node_prim.get()->GetAttr("out_channel");
135   MS_EXCEPTION_IF_NULL(second_node_channel);
136   auto second_imm_ptr = second_node_channel->cast<Int64ImmPtr>();
137   MS_EXCEPTION_IF_NULL(second_imm_ptr);
138   size_t second_channel = second_imm_ptr->value();
139   size_t cover_index = (first_channel >= second_channel) ? 0 : 1;
140   bool ret = ExistDependencyFromAcc2Cover(inplace_node, cover_index);
141   if (ret) {
142     return {0, false};
143   }
144   return {cover_index, true};
145 }
146 
CopyKernelInfo(AnfNodePtr src,AnfNodePtr dst)147 void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) {
148   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(src);
149   AnfAlgo::SetSelectKernelBuildInfo(build_info, dst.get());
150   size_t output_num = AnfAlgo::GetOutputTensorNum(src);
151   std::vector<TypeId> types;
152   std::vector<std::vector<size_t>> shapes;
153   for (size_t i = 0; i < output_num; i++) {
154     types.emplace_back(AnfAlgo::GetOutputInferDataType(src, i));
155     shapes.emplace_back(AnfAlgo::GetOutputInferShape(src, i));
156   }
157   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, dst.get());
158 }
159 
CheckInplaceNodeInputs(std::vector<AnfNodeIndex> * inplace_node,size_t cover_index,const FuncGraphPtr & graph)160 void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cover_index, const FuncGraphPtr &graph) {
161   // If two inplace nodes have same input, will be have loop after insert depend:
162   //            A                              A     Cover <----+
163   //          /    \                            \    /          |
164   //         B      \            -->            Depend -------> B
165   //        /        \                            |
166   //      Cover      Acc                         Acc
167   // so copy a new input for one of inplace node like this
168   //        A         A'                          A           A'
169   //        |         |                           |           |
170   //        B         |          -->              B        Depend <-+
171   //        |         |                           |           |     |
172   //      Cover      Acc                          |          Acc    |
173   //                                            Cover---------------+
174   MS_EXCEPTION_IF_NULL(inplace_node);
175   MS_EXCEPTION_IF_NULL(graph);
176   size_t acc_index = cover_index == 1 ? 0 : 1;
177   const CNodePtr &cover_node = inplace_node->at(cover_index).node->cast<CNodePtr>();
178   const CNodePtr &acc_node = inplace_node->at(acc_index).node->cast<CNodePtr>();
179   MS_EXCEPTION_IF_NULL(cover_node);
180   MS_EXCEPTION_IF_NULL(acc_node);
181   const auto &acc_input = acc_node->input(1)->cast<CNodePtr>();
182   if (acc_input == nullptr) {
183     return;
184   }
185   bool ret = ExistRoute(acc_input, cover_node);
186   if (ret) {
187     auto new_input = graph->NewCNode(acc_input->inputs());
188     MS_EXCEPTION_IF_NULL(new_input);
189     new_input->set_abstract(acc_input->abstract());
190     CopyKernelInfo(acc_input, new_input);
191     auto new_inplace_node = graph->NewCNode({acc_node->input(0), new_input, acc_node->input(2)});
192     MS_EXCEPTION_IF_NULL(new_inplace_node);
193     new_inplace_node->set_abstract(acc_node->abstract());
194     CopyKernelInfo(acc_node, new_inplace_node);
195     auto manager = graph->manager();
196     MS_EXCEPTION_IF_NULL(manager);
197     manager->Replace(acc_node, new_inplace_node);
198     (*inplace_node)[acc_index].node = new_inplace_node;
199   }
200 }
201 
SetNodeAttr(AnfNodeIndex aggregate_node,AnfNodePtr skip_node,std::vector<AnfNodeIndex> * inplace_node,const FuncGraphPtr & graph)202 void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node,
203                  const FuncGraphPtr &graph) {
204   MS_EXCEPTION_IF_NULL(skip_node);
205   MS_EXCEPTION_IF_NULL(inplace_node);
206   MS_EXCEPTION_IF_NULL(graph);
207 
208   SetPrimAttr(aggregate_node.node, "aggregate", true);
209   SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
210   SetPrimAttr(skip_node, "skip", true);
211 
212   static uint32_t group = 0;
213   auto [cover_index, order_required] = GetCoverIndex(*inplace_node);
214   CheckInplaceNodeInputs(inplace_node, cover_index, graph);
215 
216   for (size_t i = 0; i < inplace_node->size(); i++) {
217     auto algo = (i == cover_index) ? "cover" : "accumulation";
218     auto node = (*inplace_node)[i].node;
219     MS_EXCEPTION_IF_NULL(node);
220     SetPrimAttr(node, "inplace_algo", algo);
221     SetPrimAttr(node, "inplace_group", group);
222     SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index);
223     // for Conv2DBackpropInputOp, need insert depend node to keep order, set the larger channel to cover
224     if (order_required && i != cover_index) {
225       auto acc_node = node;
226       auto cover_node = (*inplace_node)[cover_index].node;
227       auto cnode = acc_node->cast<CNodePtr>();
228       MS_EXCEPTION_IF_NULL(cnode);
229       auto acc_node_input = cnode->input(1);
230       std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
231                                         acc_node_input, cover_node};
232       auto depend_node = graph->NewCNode(inputs);
233       MS_EXCEPTION_IF_NULL(depend_node);
234       depend_node->set_abstract(acc_node_input->abstract());
235       auto manager = graph->manager();
236       MS_EXCEPTION_IF_NULL(manager);
237       manager->Replace(acc_node_input, depend_node);
238     }
239   }
240   group++;
241 }
242 
PatternMatch(const FuncGraphPtr & graph,const AnfNodePtr & node,AnfNodeIndex * aggregate,AnfNodePtr * skip_node,std::vector<AnfNodeIndex> * inplace)243 bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
244                   std::vector<AnfNodeIndex> *inplace) {
245   MS_EXCEPTION_IF_NULL(graph);
246   MS_EXCEPTION_IF_NULL(node);
247   MS_EXCEPTION_IF_NULL(inplace);
248   MS_EXCEPTION_IF_NULL(skip_node);
249   MS_EXCEPTION_IF_NULL(aggregate);
250   if (!node->isa<CNode>()) {
251     return false;
252   }
253   auto aggregate_iter = kAggregatesOpNames.find(AnfAlgo::GetCNodeName(node));
254   if (aggregate_iter == kAggregatesOpNames.end()) {
255     return false;
256   }
257   aggregate->node = node;
258   aggregate->index = aggregate_iter->second;
259 
260   *skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), aggregate_iter->second);
261   if (*skip_node == nullptr || !(*skip_node)->isa<CNode>() ||
262       kSkipOpNames.count(AnfAlgo::GetCNodeName(*skip_node)) == 0 ||
263       GetRealNodeUsedList(graph, *skip_node)->size() >= 2) {
264     return false;
265   }
266 
267   auto cnode = (*skip_node)->cast<CNodePtr>();
268   MS_EXCEPTION_IF_NULL(cnode);
269   size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
270   for (size_t i = 0; i < input_num; i++) {
271     auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
272     if (!inplace_node->isa<CNode>()) {
273       return false;
274     }
275     // Check Inplace nodes have no user except TensorAdd nodes
276     if (GetRealNodeUsedList(graph, inplace_node)->size() >= 2) {
277       return false;
278     }
279 
280     // skip TupleGetItem node
281     if (AnfAlgo::GetCNodeName(inplace_node) == prim::kPrimTupleGetItem->name()) {
282       inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(inplace_node), 0);
283     }
284 
285     auto inplace_iter = kInplaceOpNames.find(AnfAlgo::GetCNodeName(inplace_node));
286     if (inplace_iter == kInplaceOpNames.end()) {
287       return false;
288     }
289 
290     inplace->push_back(AnfNodeIndex(inplace_node, inplace_iter->second));
291   }
292 
293   return true;
294 }
295 
TopoIndex(const std::vector<AnfNodePtr> & node_list)296 std::map<AnfNodePtr, int> TopoIndex(const std::vector<AnfNodePtr> &node_list) {
297   std::map<AnfNodePtr, int> topo_index;
298   for (size_t i = 0; i < node_list.size(); i++) {
299     topo_index.insert(make_pair(node_list[i], i));
300   }
301   return topo_index;
302 }
303 }  // namespace
304 
Run(const FuncGraphPtr & graph)305 bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
306   MS_EXCEPTION_IF_NULL(graph);
307   std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
308   auto topo_index = TopoIndex(node_list);
309 
310   for (auto node : node_list) {
311     AnfNodeIndex aggregate_node;
312     AnfNodePtr skip_node;
313     std::vector<AnfNodeIndex> inplace_node;
314     // 1. Pattern Match.
315     if (!PatternMatch(graph, node, &aggregate_node, &skip_node, &inplace_node)) {
316       continue;
317     }
318 
319     // 2. Keep the original topological order in case the dependence between inplace nodes
320     std::sort(inplace_node.begin(), inplace_node.end(), [&topo_index](const AnfNodeIndex &n1, const AnfNodeIndex &n2) {
321       auto iter1 = topo_index.find(n1.node);
322       auto iter2 = topo_index.find(n2.node);
323       if (iter1 == topo_index.end() || iter2 == topo_index.end()) {
324         MS_LOG(EXCEPTION) << ": Node not existed in topo order. node1: " << n1.node->DebugString()
325                           << ", node2: " << n2.node->DebugString();
326       }
327 
328       if (iter1->second < iter2->second) {
329         return true;
330       }
331       return false;
332     });
333     MS_LOG(INFO) << "[inplace optimizer] aggregate node: " << aggregate_node.index << ", "
334                  << aggregate_node.node->DebugString() << "; skip node: " << skip_node->DebugString() << std::endl
335                  << "; inplace node 0: " << inplace_node[0].index << ", " << inplace_node[0].node->DebugString()
336                  << std::endl
337                  << "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
338                  << std::endl;
339     // 2. Set Node attr
340     SetNodeAttr(aggregate_node, skip_node, &inplace_node, graph);
341   }
342 
343   return true;
344 }
345 }  // namespace opt
346 }  // namespace mindspore
347