• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
17 
18 #include <memory>
19 #include <algorithm>
20 #include <vector>
21 #include <string>
22 #include <utility>
23 
24 #include "base/core_ops.h"
25 #include "ir/graph_utils.h"
26 #include "backend/optimizer/common/helper.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "debug/anf_ir_dump.h"
29 #include "backend/kernel_compiler/common_utils.h"
30 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
31 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
32 
33 namespace mindspore {
34 namespace opt {
35 namespace {
GetIndex(const AnfNodePtr & getitem_node)36 inline size_t GetIndex(const AnfNodePtr &getitem_node) {
37   MS_EXCEPTION_IF_NULL(getitem_node);
38   if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
39     MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope();
40   }
41   return LongToSize(GetValue<int64_t>(
42     getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
43 }
44 
SetIndex(const AnfNodePtr & getitem_node,size_t index)45 void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
46   auto getitem = getitem_node->cast<CNodePtr>();
47   MS_EXCEPTION_IF_NULL(getitem);
48   auto idx_node = NewValueNode(MakeValue<int64_t>(SizeToLong(index)));
49   auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
50   idx_node->set_abstract(abstract);
51   idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
52   getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
53 }
54 }  // namespace
55 
GetGraphKernelGetitemList(const FuncGraphManagerPtr & mng,const AnfNodePtr & node,AnfNodePtrList * getitem_list,bool merge_repeated_getitem)56 bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
57                                bool merge_repeated_getitem) {
58   MS_EXCEPTION_IF_NULL(mng);
59   MS_EXCEPTION_IF_NULL(getitem_list);
60   auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
61   MS_EXCEPTION_IF_NULL(func_graph);
62   auto output = func_graph->output();
63   if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
64     MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope();
65   }
66   auto output_num = output->cast<CNodePtr>()->size() - 1;
67   getitem_list->clear();
68   getitem_list->resize(output_num, nullptr);
69   auto users = mng->node_users()[node];
70   bool changed = false;
71   for (const auto &user : users) {
72     if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
73       MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem, but got: " << user.first->DebugString();
74     }
75     auto &getitem = user.first;
76     auto idx = GetIndex(getitem);
77     if (idx >= output_num) {
78       MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString();
79     }
80     if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) {
81       mng->Replace(getitem, (*getitem_list)[idx]);
82       changed = true;
83     } else {
84       (*getitem_list)[idx] = getitem;
85     }
86   }
87   return changed;
88 }
89 
FindGraphKernelsWithMultiOutput(const FuncGraphPtr & func_graph)90 AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
91   auto todos = TopoSort(func_graph->get_return());
92   AnfNodePtrList result;
93   std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
94     return AnfAlgo::IsGraphKernel(node) &&
95            IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(node)->output(), prim::kPrimMakeTuple);
96   });
97   return result;
98 }
99 
IsSideEffectNode(const AnfNodePtr & node)100 bool IsSideEffectNode(const AnfNodePtr &node) {
101   std::vector<PrimitivePtr> side_effect_nodes = {prim::kPrimAssign, prim::kPrimInplaceAssign};
102   return std::any_of(side_effect_nodes.begin(), side_effect_nodes.end(),
103                      [&node](const PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
104 }
105 
106 /* Unify the repeated output in a func_graph.
107  *   %1 = call @graph_kernel(p1, p2)
108  *   %2 = tuple_getitem(%1, 0)
109  *   %3 = tuple_getitem(%1, 1)
110  *   graph_kernel:
111  *      %1 = TensorAdd(p1, p2)
112  *      %2 = Reshape(%1)
113  *      return make_tuple(%2, %2)
114  * -->
115  *   %1 = call @graph_kernel(p1, p2)
116  *   %2 = tuple_getitem(%1, 0)
117  *   %3 = tuple_getitem(%1, 0)   // changed the index to 0.
118  *   graph_kernel:
119  *      %1 = TensorAdd(p1, p2)
120  *      %2 = Reshape(%1)
121  *      return make_tuple(%2, %2)
122  */
123 class UnifyRepeatedOutput : public Pass {
124  public:
Run(const FuncGraphPtr & func_graph)125   bool Run(const FuncGraphPtr &func_graph) override {
126     auto mng = func_graph->manager();
127     MS_EXCEPTION_IF_NULL(mng);
128     auto todos = FindGraphKernelsWithMultiOutput(func_graph);
129     bool changed = false;
130     for (auto node : todos) {
131       if (CheckRepeatedOutput(AnfAlgo::GetCNodeFuncGraphPtr(node))) {
132         changed = true;
133         AnfNodePtrList getitem_list;
134         GetGraphKernelGetitemList(mng, node, &getitem_list, false);
135         if (getitem_list.size() != index_map_.size()) {
136           MS_LOG(EXCEPTION) << "getitem_list.size (" << getitem_list.size() << ") should be equal to index_map.size ("
137                             << index_map_.size() << ").";
138         }
139         for (size_t i = 0; i < index_map_.size(); ++i) {
140           if (index_map_[i] != i && getitem_list[i] != nullptr) {
141             SetIndex(getitem_list[i], index_map_[i]);
142           }
143         }
144       }
145     }
146     return changed;
147   }
148 
149  private:
CheckRepeatedOutput(const FuncGraphPtr & sub_func_graph)150   bool CheckRepeatedOutput(const FuncGraphPtr &sub_func_graph) {
151     // the output should be a MakeTuple.
152     auto maketuple = sub_func_graph->output()->cast<CNodePtr>();
153     MS_EXCEPTION_IF_NULL(maketuple);
154     AnfNodePtrList outputs(maketuple->inputs().begin() + 1, maketuple->inputs().end());
155     index_map_.resize(outputs.size());
156     bool found = false;
157     for (size_t i = 0; i < outputs.size(); ++i) {
158       index_map_[i] =
159         static_cast<size_t>(std::find(outputs.begin(), outputs.begin() + i, outputs[i]) - outputs.begin());
160       if (index_map_[i] != i) {
161         found = true;
162       }
163     }
164     return found;
165   }
166   std::vector<size_t> index_map_;
167 };
168 
169 /* Unify the get_item nodes that have same index.
170  *   %1 = call @graph_kernel(p1, p2)
171  *   %2 = tuple_getitem(%1, 0)
172  *   %3 = tuple_getitem(%1, 0)
173  *   %4 = tuple_getitem(%1, 1)
174  *   %5 = user_x(%2)
175  *   %6 = user_y(%3)
176  *   %7 = user_z(%4)
177  *   --->
178  *   %1 = call @graph_kernel(p1, p2)
179  *   %2 = tuple_getitem(%1, 0) // unify the original %2 and %3
180  *   %3 = tuple_getitem(%1, 1)
181  *   %4 = user_x(%2)
182  *   %5 = user_y(%2)
183  *   %6 = user_z(%3)
184  */
185 class UnifyRepeatedGetitem : public Pass {
186  public:
Run(const FuncGraphPtr & func_graph)187   bool Run(const FuncGraphPtr &func_graph) override {
188     auto mng = func_graph->manager();
189     MS_EXCEPTION_IF_NULL(mng);
190     auto todos = FindGraphKernelsWithMultiOutput(func_graph);
191     bool changed = false;
192     for (auto node : todos) {
193       AnfNodePtrList getitem_list;
194       changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed;
195     }
196     return changed;
197   }
198 };
199 
Run(const FuncGraphPtr & func_graph)200 bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
201   auto mng = func_graph->manager();
202   if (mng == nullptr) {
203     mng = Manage(func_graph, true);
204     func_graph->set_manager(mng);
205   }
206   bool changed = false;
207   changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
208   changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
209   changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
210   changed = std::make_shared<EliminateHangingOutput>()->Run(func_graph) || changed;
211   return changed;
212 }
213 
UpdateGetitemIndex(const AnfNodePtr & getitem,size_t offset) const214 void EliminateHangingOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) const {
215   if (offset == 0) return;
216   MS_EXCEPTION_IF_NULL(getitem);
217   auto index = GetIndex(getitem);
218   if (offset > index) {
219     MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString();
220   }
221   index -= offset;
222   SetIndex(getitem, index);
223 }
224 
// Removes, from the MakeTuple output of `node`'s sub-graph, every output with no getitem user
// (getitems[i] == nullptr), except Assign/InplaceAssign outputs.
// Getitems of the kept outputs are re-indexed in place to account for the removed positions.
// If exactly one output remains, it becomes the sub-graph's output directly; otherwise a new MakeTuple is built.
// Returns a newly created graph-kernel CNode wrapping the modified sub-graph, or nullptr if nothing was removed.
// Raises if every output would be removed.
AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) const {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(sub_graph);
  auto old_maketuple = sub_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_maketuple);

  AnfNodePtrList kept_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList kept_abstracts;
  size_t removed = 0;
  for (size_t i = 0; i < getitems.size(); ++i) {
    const auto &getitem = getitems[i];
    const auto &tuple_item = old_maketuple->input(i + 1);
    const bool unused = (getitem == nullptr);
    // An unused output is dropped unless it is a side-effect node.
    if (unused && !IsSideEffectNode(tuple_item)) {
      ++removed;
      continue;
    }
    kept_inputs.push_back(tuple_item);
    kept_abstracts.push_back(tuple_item->abstract());
    if (!unused) {
      UpdateGetitemIndex(getitem, removed);
    }
  }
  if (removed == 0) {
    return nullptr;
  }

  const size_t kept_count = kept_inputs.size() - 1;  // exclude the MakeTuple primitive value node
  if (kept_count == 0) {
    MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";
  }
  if (kept_count == 1) {
    sub_graph->set_output(kept_inputs.back());
  } else {
    auto new_maketuple = sub_graph->NewCNode(kept_inputs);
    new_maketuple->set_abstract(std::make_shared<abstract::AbstractTuple>(kept_abstracts));
    new_maketuple->set_kernel_info(std::make_shared<device::KernelInfo>());
    sub_graph->set_output(new_maketuple);
  }

  auto old_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_cnode);
  const auto &old_inputs = old_cnode->inputs();
  AnfNodePtrList fused_inputs(old_inputs.begin() + 1, old_inputs.end());
  AnfNodePtrList fused_outputs;
  kernel::GetFuncGraphOutputNodes(sub_graph, &fused_outputs);
  auto fused_node = CreateNewFuseCNode(node->func_graph(), sub_graph, fused_inputs, fused_outputs);
  SetNewKernelInfo(fused_node, sub_graph, fused_inputs, fused_outputs);
  return fused_node;
}
267 
Run(const FuncGraphPtr & func_graph)268 bool EliminateHangingOutput::Run(const FuncGraphPtr &func_graph) {
269   auto mng = func_graph->manager();
270   MS_EXCEPTION_IF_NULL(mng);
271   auto todos = FindGraphKernelsWithMultiOutput(func_graph);
272   bool changed = false;
273   for (auto node : todos) {
274     AnfNodePtrList getitems;
275     GetGraphKernelGetitemList(mng, node, &getitems, false);
276     auto new_node = ReplaceMakeTuple(node, getitems);
277     if (new_node != nullptr) {
278       if (!IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(new_node)->output(), prim::kPrimMakeTuple)) {
279         // only one output, remove the getitem.
280         auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; });
281         if (i != getitems.end()) {
282           mng->Replace(*i, new_node);
283         }
284       } else {
285         mng->Replace(node, new_node);
286       }
287       changed = true;
288     }
289   }
290   return changed;
291 }
292 }  // namespace opt
293 }  // namespace mindspore
294