• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/eliminate_redundant_output.h"
17 
18 #include <memory>
19 #include <algorithm>
20 #include <vector>
21 #include <string>
22 #include <utility>
23 
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/nn_optimizer_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "ir/anf.h"
28 #include "ir/graph_utils.h"
29 #include "utils/anf_utils.h"
30 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
31 #include "backend/common/graph_kernel/core/graph_builder.h"
32 
33 namespace mindspore::graphkernel {
34 namespace {
GetIndex(const AnfNodePtr & getitem_node)35 inline size_t GetIndex(const AnfNodePtr &getitem_node) {
36   MS_EXCEPTION_IF_NULL(getitem_node);
37   if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
38     MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope();
39   }
40   return LongToSize(GetValue<int64_t>(
41     getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
42 }
43 
SetIndex(const AnfNodePtr & getitem_node,size_t index)44 void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
45   auto getitem = getitem_node->cast<CNodePtr>();
46   MS_EXCEPTION_IF_NULL(getitem);
47   auto idx_node = NewValueNode(MakeValue<int64_t>(SizeToLong(index)));
48   auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
49   idx_node->set_abstract(abstract);
50   Callback::Instance()->SetEmptyKernelInfo(idx_node);
51   getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
52 }
53 }  // namespace
54 
GetGraphKernelGetitemList(const FuncGraphManagerPtr & mng,const AnfNodePtr & node,AnfNodePtrList * getitem_list,bool merge_repeated_getitem)55 bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
56                                bool merge_repeated_getitem) {
57   MS_EXCEPTION_IF_NULL(mng);
58   MS_EXCEPTION_IF_NULL(getitem_list);
59   auto func_graph = GetCNodeFuncGraph(node);
60   MS_EXCEPTION_IF_NULL(func_graph);
61   auto output = func_graph->output();
62   if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
63     MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope();
64   }
65   auto output_num = output->cast<CNodePtr>()->size() - 1;
66   getitem_list->clear();
67   getitem_list->resize(output_num, nullptr);
68   auto users = mng->node_users()[node];
69   bool changed = false;
70   for (const auto &user : users) {
71     if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
72       MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem, but got: " << user.first->DebugString();
73     }
74     auto &getitem = user.first;
75     auto idx = GetIndex(getitem);
76     if (idx >= output_num) {
77       MS_LOG(EXCEPTION) << "Index of GetItem is " << idx << ", which is out of range of MakeTuple [0, " << output_num
78                         << "). GetItem node: " << getitem->DebugString();
79     }
80     if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) {
81       (void)mng->Replace(getitem, (*getitem_list)[idx]);
82       changed = true;
83     } else {
84       (*getitem_list)[idx] = getitem;
85     }
86   }
87   return changed;
88 }
89 
FindGraphKernelsWithMultiOutput(const FuncGraphPtr & func_graph)90 AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
91   auto todos = TopoSort(func_graph->get_return());
92   AnfNodePtrList result;
93   (void)std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
94     return AnfUtils::IsGraphKernel(node) && IsPrimitiveCNode(GetCNodeFuncGraph(node)->output(), prim::kPrimMakeTuple);
95   });
96   return result;
97 }
98 
IsSideEffectNode(const AnfNodePtr & node)99 bool IsSideEffectNode(const AnfNodePtr &node) {
100   std::vector<PrimitivePtr> side_effect_nodes = {prim::kPrimAssign};
101   return std::any_of(side_effect_nodes.begin(), side_effect_nodes.end(),
102                      [&node](const PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
103 }
104 
105 /* Unify the repeated output in a func_graph.
106  *   %1 = call @graph_kernel(p1, p2)
107  *   %2 = tuple_getitem(%1, 0)
108  *   %3 = tuple_getitem(%1, 1)
109  *   graph_kernel:
110  *      %1 = TensorAdd(p1, p2)
111  *      %2 = Reshape(%1)
112  *      return make_tuple(%2, %2)
113  * -->
114  *   %1 = call @graph_kernel(p1, p2)
115  *   %2 = tuple_getitem(%1, 0)
116  *   %3 = tuple_getitem(%1, 0)   // changed the index to 0.
117  *   graph_kernel:
118  *      %1 = TensorAdd(p1, p2)
119  *      %2 = Reshape(%1)
120  *      return make_tuple(%2, %2)
121  */
122 class UnifyRepeatedOutput : public opt::Pass {
123  public:
Run(const FuncGraphPtr & func_graph)124   bool Run(const FuncGraphPtr &func_graph) override {
125     auto mng = func_graph->manager();
126     MS_EXCEPTION_IF_NULL(mng);
127     auto todos = FindGraphKernelsWithMultiOutput(func_graph);
128     bool changed = false;
129     for (auto node : todos) {
130       if (CheckRepeatedOutput(GetCNodeFuncGraph(node))) {
131         changed = true;
132         AnfNodePtrList getitem_list;
133         (void)GetGraphKernelGetitemList(mng, node, &getitem_list, false);
134         if (getitem_list.size() != index_map_.size()) {
135           MS_LOG(EXCEPTION) << "getitem_list.size (" << getitem_list.size() << ") should be equal to index_map.size ("
136                             << index_map_.size() << ").";
137         }
138         for (size_t i = 0; i < index_map_.size(); ++i) {
139           if (index_map_[i] != i && getitem_list[i] != nullptr) {
140             SetIndex(getitem_list[i], index_map_[i]);
141           }
142         }
143       }
144     }
145     return changed;
146   }
147 
148  private:
CheckRepeatedOutput(const FuncGraphPtr & sub_func_graph)149   bool CheckRepeatedOutput(const FuncGraphPtr &sub_func_graph) {
150     // the output should be a MakeTuple.
151     auto maketuple = sub_func_graph->output()->cast<CNodePtr>();
152     MS_EXCEPTION_IF_NULL(maketuple);
153     AnfNodePtrList outputs(maketuple->inputs().begin() + 1, maketuple->inputs().end());
154     index_map_.resize(outputs.size());
155     bool found = false;
156     for (size_t i = 0; i < outputs.size(); ++i) {
157       index_map_[i] =
158         static_cast<size_t>(std::find(outputs.begin(), outputs.begin() + SizeToLong(i), outputs[i]) - outputs.begin());
159       if (index_map_[i] != i) {
160         found = true;
161       }
162     }
163     return found;
164   }
165   std::vector<size_t> index_map_;
166 };
167 
168 /* Unify the get_item nodes that have same index.
169  *   %1 = call @graph_kernel(p1, p2)
170  *   %2 = tuple_getitem(%1, 0)
171  *   %3 = tuple_getitem(%1, 0)
172  *   %4 = tuple_getitem(%1, 1)
173  *   %5 = user_x(%2)
174  *   %6 = user_y(%3)
175  *   %7 = user_z(%4)
176  *   --->
177  *   %1 = call @graph_kernel(p1, p2)
178  *   %2 = tuple_getitem(%1, 0) // unify the original %2 and %3
179  *   %3 = tuple_getitem(%1, 1)
180  *   %4 = user_x(%2)
181  *   %5 = user_y(%2)
182  *   %6 = user_z(%3)
183  */
184 class UnifyRepeatedGetitem : public opt::Pass {
185  public:
Run(const FuncGraphPtr & func_graph)186   bool Run(const FuncGraphPtr &func_graph) override {
187     auto mng = func_graph->manager();
188     MS_EXCEPTION_IF_NULL(mng);
189     auto todos = FindGraphKernelsWithMultiOutput(func_graph);
190     bool changed = false;
191     for (auto node : todos) {
192       AnfNodePtrList getitem_list;
193       changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed;
194     }
195     return changed;
196   }
197 };
198 
199 /* if a graphkernel node of multi-output is directly used by other kernel,
200  * change it to use getitem-maketuple.
201  *   %1 = call @graph_kernel(p1, p2)  // assume it has 3 outputs.
202  *   %2 = AddN(%1)
203  *   --->
204  *   %1 = call @graph_kernel(p1, p2)
205  *   %2 = tuple_getitem(%1, 0)
206  *   %3 = tuple_getitem(%1, 1)
207  *   %4 = tuple_getitem(%1, 2)
208  *   %5 = make_tuple(%2, %3, %4)
209  *   %6 = AddN(%5)
210  */
211 class TupleNodeFormatter : public opt::Pass {
212  public:
Run(const FuncGraphPtr & func_graph)213   bool Run(const FuncGraphPtr &func_graph) override {
214     auto mng = func_graph->manager();
215     MS_EXCEPTION_IF_NULL(mng);
216     auto todos = FindGraphKernelsWithMultiOutput(func_graph);
217     bool changed = false;
218     for (auto &node : todos) {
219       auto &users = mng->node_users()[node];
220       for (auto &user : users) {
221         if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
222           auto mt = TransToMaketuple(node);
223           (void)mng->Replace(node, mt);
224           changed = true;
225           break;
226         }
227       }
228     }
229     if (changed) {
230       (void)EliminateMaketupleGetitem(func_graph);
231     }
232     return changed;
233   }
234 
TransToMaketuple(const AnfNodePtr & node) const235   AnfNodePtr TransToMaketuple(const AnfNodePtr &node) const {
236     auto fg = node->func_graph();
237     MS_EXCEPTION_IF_NULL(fg);
238     auto node_abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
239     MS_EXCEPTION_IF_NULL(node_abs);
240     auto output_num = node_abs->size();
241     AnfNodePtrList mt_inputs{NewValueNode(prim::kPrimMakeTuple)};
242     mt_inputs.reserve(output_num + 1);
243     for (size_t i = 0; i < output_num; i++) {
244       auto idx = MakeValue(SizeToLong(i));
245       AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
246       gt_inputs.back()->set_abstract(idx->ToAbstract());
247       auto &gt = mt_inputs.emplace_back(fg->NewCNode(gt_inputs));
248       gt->set_abstract(node_abs->elements()[i]);
249       Callback::Instance()->SetEmptyKernelInfo(gt);
250     }
251     auto mt = fg->NewCNode(mt_inputs);
252     mt->set_abstract(node_abs);
253     Callback::Instance()->SetEmptyKernelInfo(mt);
254     return mt;
255   }
256 };
257 
Run(const FuncGraphPtr & func_graph)258 bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
259   auto mng = func_graph->manager();
260   if (mng == nullptr) {
261     mng = Manage(func_graph, true);
262     func_graph->set_manager(mng);
263   }
264   bool changed = false;
265   changed = std::make_shared<TupleNodeFormatter>()->Run(func_graph) || changed;
266   changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
267   changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
268   changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
269   changed = std::make_shared<EliminateHangingOutput>()->Run(func_graph) || changed;
270   return changed;
271 }
272 
UpdateGetitemIndex(const AnfNodePtr & getitem,size_t offset) const273 void EliminateHangingOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) const {
274   if (offset == 0) {
275     return;
276   }
277   MS_EXCEPTION_IF_NULL(getitem);
278   auto index = GetIndex(getitem);
279   if (offset > index) {
280     MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString() << ". "
281                       << offset << " vs " << index;
282   }
283   index -= offset;
284   SetIndex(getitem, index);
285 }
286 
// Drops the outputs of the graph kernel `node` that have no getitem user (side-effect nodes are
// kept), shifts the indices of the remaining getitems, and builds a new fused node over the
// shrunk sub graph.
//
// `getitems[i]` is the getitem that reads output i, or nullptr when there is none.
// Returns nullptr when nothing was dropped; otherwise the node created by CreateNewFuseCNode.
// When exactly one output remains, the sub graph output becomes that node instead of a MakeTuple.
// Throws when every output would be dropped.
AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) const {
  auto sub_graph = GetCNodeFuncGraph(node);
  MS_EXCEPTION_IF_NULL(sub_graph);
  auto old_maketuple = sub_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_maketuple);
  AnfNodePtrList kept_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList kept_abstracts;
  size_t removed = 0;
  for (size_t i = 0; i < getitems.size(); ++i) {
    const auto &output = old_maketuple->input(i + 1);
    const bool has_user = (getitems[i] != nullptr);
    // If a node has no user, it should be eliminated, but except for side-effect node.
    if (!has_user && !IsSideEffectNode(output)) {
      ++removed;
      continue;
    }
    kept_inputs.push_back(output);
    kept_abstracts.push_back(output->abstract());
    if (has_user) {
      // Shift by the number of outputs dropped before this one.
      UpdateGetitemIndex(getitems[i], removed);
    }
  }
  if (removed == 0) {
    return nullptr;
  }
  if (kept_inputs.size() == 1) {
    MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";
  }
  // Primitive plus exactly one real output.
  const size_t single_output_input_size = 2;
  if (kept_inputs.size() == single_output_input_size) {
    sub_graph->set_output(kept_inputs.back());
  } else {
    auto new_maketuple = sub_graph->NewCNode(kept_inputs);
    new_maketuple->set_abstract(std::make_shared<abstract::AbstractTuple>(kept_abstracts));
    Callback::Instance()->SetEmptyKernelInfo(new_maketuple);
    sub_graph->set_output(new_maketuple);
  }

  auto old_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_cnode);
  // Input 0 of the call node is skipped; the rest are the real arguments.
  const auto &call_inputs = old_cnode->inputs();
  AnfNodePtrList kernel_args(call_inputs.begin() + 1, call_inputs.end());
  return CreateNewFuseCNode(node->func_graph(), sub_graph, kernel_args);
}
329 
Run(const FuncGraphPtr & func_graph)330 bool EliminateHangingOutput::Run(const FuncGraphPtr &func_graph) {
331   auto mng = func_graph->manager();
332   MS_EXCEPTION_IF_NULL(mng);
333   auto todos = FindGraphKernelsWithMultiOutput(func_graph);
334   bool changed = false;
335   for (auto node : todos) {
336     AnfNodePtrList getitems;
337     (void)GetGraphKernelGetitemList(mng, node, &getitems, false);
338     auto new_node = ReplaceMakeTuple(node, getitems);
339     if (new_node != nullptr) {
340       if (!IsPrimitiveCNode(GetCNodeFuncGraph(new_node)->output(), prim::kPrimMakeTuple)) {
341         // only one output, remove the getitem.
342         auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; });
343         if (i != getitems.end()) {
344           (void)mng->Replace(*i, new_node);
345         }
346       } else {
347         (void)mng->Replace(node, new_node);
348       }
349       changed = true;
350     }
351   }
352   return changed;
353 }
354 }  // namespace mindspore::graphkernel
355