1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
17
18 #include <memory>
19 #include <algorithm>
20 #include <vector>
21 #include <string>
22 #include <utility>
23
24 #include "base/core_ops.h"
25 #include "ir/graph_utils.h"
26 #include "backend/optimizer/common/helper.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "debug/anf_ir_dump.h"
29 #include "backend/kernel_compiler/common_utils.h"
30 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
31 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
32
33 namespace mindspore {
34 namespace opt {
35 namespace {
GetIndex(const AnfNodePtr & getitem_node)36 inline size_t GetIndex(const AnfNodePtr &getitem_node) {
37 MS_EXCEPTION_IF_NULL(getitem_node);
38 if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
39 MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope();
40 }
41 return LongToSize(GetValue<int64_t>(
42 getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
43 }
44
SetIndex(const AnfNodePtr & getitem_node,size_t index)45 void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
46 auto getitem = getitem_node->cast<CNodePtr>();
47 MS_EXCEPTION_IF_NULL(getitem);
48 auto idx_node = NewValueNode(MakeValue<int64_t>(SizeToLong(index)));
49 auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
50 idx_node->set_abstract(abstract);
51 idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
52 getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
53 }
54 } // namespace
55
GetGraphKernelGetitemList(const FuncGraphManagerPtr & mng,const AnfNodePtr & node,AnfNodePtrList * getitem_list,bool merge_repeated_getitem)56 bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
57 bool merge_repeated_getitem) {
58 MS_EXCEPTION_IF_NULL(mng);
59 MS_EXCEPTION_IF_NULL(getitem_list);
60 auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
61 MS_EXCEPTION_IF_NULL(func_graph);
62 auto output = func_graph->output();
63 if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
64 MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope();
65 }
66 auto output_num = output->cast<CNodePtr>()->size() - 1;
67 getitem_list->clear();
68 getitem_list->resize(output_num, nullptr);
69 auto users = mng->node_users()[node];
70 bool changed = false;
71 for (const auto &user : users) {
72 if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
73 MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem, but got: " << user.first->DebugString();
74 }
75 auto &getitem = user.first;
76 auto idx = GetIndex(getitem);
77 if (idx >= output_num) {
78 MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString();
79 }
80 if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) {
81 mng->Replace(getitem, (*getitem_list)[idx]);
82 changed = true;
83 } else {
84 (*getitem_list)[idx] = getitem;
85 }
86 }
87 return changed;
88 }
89
FindGraphKernelsWithMultiOutput(const FuncGraphPtr & func_graph)90 AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
91 auto todos = TopoSort(func_graph->get_return());
92 AnfNodePtrList result;
93 std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
94 return AnfAlgo::IsGraphKernel(node) &&
95 IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(node)->output(), prim::kPrimMakeTuple);
96 });
97 return result;
98 }
99
IsSideEffectNode(const AnfNodePtr & node)100 bool IsSideEffectNode(const AnfNodePtr &node) {
101 std::vector<PrimitivePtr> side_effect_nodes = {prim::kPrimAssign, prim::kPrimInplaceAssign};
102 return std::any_of(side_effect_nodes.begin(), side_effect_nodes.end(),
103 [&node](const PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
104 }
105
106 /* Unify the repeated output in a func_graph.
107 * %1 = call @graph_kernel(p1, p2)
108 * %2 = tuple_getitem(%1, 0)
109 * %3 = tuple_getitem(%1, 1)
110 * graph_kernel:
111 * %1 = TensorAdd(p1, p2)
112 * %2 = Reshape(%1)
113 * return make_tuple(%2, %2)
114 * -->
115 * %1 = call @graph_kernel(p1, p2)
116 * %2 = tuple_getitem(%1, 0)
117 * %3 = tuple_getitem(%1, 0) // changed the index to 0.
118 * graph_kernel:
119 * %1 = TensorAdd(p1, p2)
120 * %2 = Reshape(%1)
121 * return make_tuple(%2, %2)
122 */
123 class UnifyRepeatedOutput : public Pass {
124 public:
Run(const FuncGraphPtr & func_graph)125 bool Run(const FuncGraphPtr &func_graph) override {
126 auto mng = func_graph->manager();
127 MS_EXCEPTION_IF_NULL(mng);
128 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
129 bool changed = false;
130 for (auto node : todos) {
131 if (CheckRepeatedOutput(AnfAlgo::GetCNodeFuncGraphPtr(node))) {
132 changed = true;
133 AnfNodePtrList getitem_list;
134 GetGraphKernelGetitemList(mng, node, &getitem_list, false);
135 if (getitem_list.size() != index_map_.size()) {
136 MS_LOG(EXCEPTION) << "getitem_list.size (" << getitem_list.size() << ") should be equal to index_map.size ("
137 << index_map_.size() << ").";
138 }
139 for (size_t i = 0; i < index_map_.size(); ++i) {
140 if (index_map_[i] != i && getitem_list[i] != nullptr) {
141 SetIndex(getitem_list[i], index_map_[i]);
142 }
143 }
144 }
145 }
146 return changed;
147 }
148
149 private:
CheckRepeatedOutput(const FuncGraphPtr & sub_func_graph)150 bool CheckRepeatedOutput(const FuncGraphPtr &sub_func_graph) {
151 // the output should be a MakeTuple.
152 auto maketuple = sub_func_graph->output()->cast<CNodePtr>();
153 MS_EXCEPTION_IF_NULL(maketuple);
154 AnfNodePtrList outputs(maketuple->inputs().begin() + 1, maketuple->inputs().end());
155 index_map_.resize(outputs.size());
156 bool found = false;
157 for (size_t i = 0; i < outputs.size(); ++i) {
158 index_map_[i] =
159 static_cast<size_t>(std::find(outputs.begin(), outputs.begin() + i, outputs[i]) - outputs.begin());
160 if (index_map_[i] != i) {
161 found = true;
162 }
163 }
164 return found;
165 }
166 std::vector<size_t> index_map_;
167 };
168
169 /* Unify the get_item nodes that have same index.
170 * %1 = call @graph_kernel(p1, p2)
171 * %2 = tuple_getitem(%1, 0)
172 * %3 = tuple_getitem(%1, 0)
173 * %4 = tuple_getitem(%1, 1)
174 * %5 = user_x(%2)
175 * %6 = user_y(%3)
176 * %7 = user_z(%4)
177 * --->
178 * %1 = call @graph_kernel(p1, p2)
179 * %2 = tuple_getitem(%1, 0) // unify the original %2 and %3
180 * %3 = tuple_getitem(%1, 1)
181 * %4 = user_x(%2)
182 * %5 = user_y(%2)
183 * %6 = user_z(%3)
184 */
185 class UnifyRepeatedGetitem : public Pass {
186 public:
Run(const FuncGraphPtr & func_graph)187 bool Run(const FuncGraphPtr &func_graph) override {
188 auto mng = func_graph->manager();
189 MS_EXCEPTION_IF_NULL(mng);
190 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
191 bool changed = false;
192 for (auto node : todos) {
193 AnfNodePtrList getitem_list;
194 changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed;
195 }
196 return changed;
197 }
198 };
199
Run(const FuncGraphPtr & func_graph)200 bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
201 auto mng = func_graph->manager();
202 if (mng == nullptr) {
203 mng = Manage(func_graph, true);
204 func_graph->set_manager(mng);
205 }
206 bool changed = false;
207 changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
208 changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
209 changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
210 changed = std::make_shared<EliminateHangingOutput>()->Run(func_graph) || changed;
211 return changed;
212 }
213
UpdateGetitemIndex(const AnfNodePtr & getitem,size_t offset) const214 void EliminateHangingOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) const {
215 if (offset == 0) return;
216 MS_EXCEPTION_IF_NULL(getitem);
217 auto index = GetIndex(getitem);
218 if (offset > index) {
219 MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString();
220 }
221 index -= offset;
222 SetIndex(getitem, index);
223 }
224
// Remove unread outputs from the sub-graph's MakeTuple and build a new graph-kernel node over the same inputs.
// @param node The graph-kernel CNode; its sub-graph is modified in place (set_output).
// @param getitems Slot i holds the getitem reading output i, or nullptr. Kept getitems are re-indexed in place.
// @return The new graph-kernel node, or nullptr when no output was removed.
// Raises (MS_LOG(EXCEPTION)) if every output would be removed.
AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) const {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(sub_graph);
  auto old_maketuple = sub_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_maketuple);

  AnfNodePtrList kept_outputs = {NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList kept_abstracts;
  size_t removed = 0;
  for (size_t i = 0; i < getitems.size(); ++i) {
    auto out = old_maketuple->input(i + 1);
    const auto &getitem = getitems[i];
    // If a node has no user, it should be eliminated, but except for side-effect node.
    const bool drop = (getitem == nullptr) && !IsSideEffectNode(out);
    if (drop) {
      ++removed;
      continue;
    }
    kept_outputs.push_back(out);
    kept_abstracts.push_back(out->abstract());
    if (getitem != nullptr) {
      // Shift by the number of outputs dropped before position i.
      UpdateGetitemIndex(getitem, removed);
    }
  }
  if (removed == 0) {
    return nullptr;
  }
  if (kept_outputs.size() == 1) {
    MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";
  }
  if (kept_outputs.size() == 2) {
    // Exactly one output survives: make it the sub-graph output directly, without a MakeTuple wrapper.
    sub_graph->set_output(kept_outputs.back());
  } else {
    auto new_maketuple = sub_graph->NewCNode(kept_outputs);
    new_maketuple->set_abstract(std::make_shared<abstract::AbstractTuple>(kept_abstracts));
    new_maketuple->set_kernel_info(std::make_shared<device::KernelInfo>());
    sub_graph->set_output(new_maketuple);
  }

  // Create a new graph-kernel CNode from the original inputs and the modified sub-graph, then set its kernel info.
  auto old_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_cnode);
  const auto &old_inputs = old_cnode->inputs();
  AnfNodePtrList inputs(old_inputs.begin() + 1, old_inputs.end());
  AnfNodePtrList outputs;
  kernel::GetFuncGraphOutputNodes(sub_graph, &outputs);
  auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), sub_graph, inputs, outputs);
  SetNewKernelInfo(graph_kernel_node, sub_graph, inputs, outputs);
  return graph_kernel_node;
}
267
Run(const FuncGraphPtr & func_graph)268 bool EliminateHangingOutput::Run(const FuncGraphPtr &func_graph) {
269 auto mng = func_graph->manager();
270 MS_EXCEPTION_IF_NULL(mng);
271 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
272 bool changed = false;
273 for (auto node : todos) {
274 AnfNodePtrList getitems;
275 GetGraphKernelGetitemList(mng, node, &getitems, false);
276 auto new_node = ReplaceMakeTuple(node, getitems);
277 if (new_node != nullptr) {
278 if (!IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(new_node)->output(), prim::kPrimMakeTuple)) {
279 // only one output, remove the getitem.
280 auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; });
281 if (i != getitems.end()) {
282 mng->Replace(*i, new_node);
283 }
284 } else {
285 mng->Replace(node, new_node);
286 }
287 changed = true;
288 }
289 }
290 return changed;
291 }
292 } // namespace opt
293 } // namespace mindspore
294