1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/core/eliminate_redundant_output.h"
17
18 #include <memory>
19 #include <algorithm>
20 #include <vector>
21 #include <string>
22 #include <utility>
23
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/nn_optimizer_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "ir/anf.h"
28 #include "ir/graph_utils.h"
29 #include "utils/anf_utils.h"
30 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
31 #include "backend/common/graph_kernel/core/graph_builder.h"
32
33 namespace mindspore::graphkernel {
34 namespace {
GetIndex(const AnfNodePtr & getitem_node)35 inline size_t GetIndex(const AnfNodePtr &getitem_node) {
36 MS_EXCEPTION_IF_NULL(getitem_node);
37 if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
38 MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope();
39 }
40 return LongToSize(GetValue<int64_t>(
41 getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
42 }
43
SetIndex(const AnfNodePtr & getitem_node,size_t index)44 void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
45 auto getitem = getitem_node->cast<CNodePtr>();
46 MS_EXCEPTION_IF_NULL(getitem);
47 auto idx_node = NewValueNode(MakeValue<int64_t>(SizeToLong(index)));
48 auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
49 idx_node->set_abstract(abstract);
50 Callback::Instance()->SetEmptyKernelInfo(idx_node);
51 getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
52 }
53 } // namespace
54
GetGraphKernelGetitemList(const FuncGraphManagerPtr & mng,const AnfNodePtr & node,AnfNodePtrList * getitem_list,bool merge_repeated_getitem)55 bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
56 bool merge_repeated_getitem) {
57 MS_EXCEPTION_IF_NULL(mng);
58 MS_EXCEPTION_IF_NULL(getitem_list);
59 auto func_graph = GetCNodeFuncGraph(node);
60 MS_EXCEPTION_IF_NULL(func_graph);
61 auto output = func_graph->output();
62 if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
63 MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope();
64 }
65 auto output_num = output->cast<CNodePtr>()->size() - 1;
66 getitem_list->clear();
67 getitem_list->resize(output_num, nullptr);
68 auto users = mng->node_users()[node];
69 bool changed = false;
70 for (const auto &user : users) {
71 if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
72 MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem, but got: " << user.first->DebugString();
73 }
74 auto &getitem = user.first;
75 auto idx = GetIndex(getitem);
76 if (idx >= output_num) {
77 MS_LOG(EXCEPTION) << "Index of GetItem is " << idx << ", which is out of range of MakeTuple [0, " << output_num
78 << "). GetItem node: " << getitem->DebugString();
79 }
80 if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) {
81 (void)mng->Replace(getitem, (*getitem_list)[idx]);
82 changed = true;
83 } else {
84 (*getitem_list)[idx] = getitem;
85 }
86 }
87 return changed;
88 }
89
FindGraphKernelsWithMultiOutput(const FuncGraphPtr & func_graph)90 AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
91 auto todos = TopoSort(func_graph->get_return());
92 AnfNodePtrList result;
93 (void)std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
94 return AnfUtils::IsGraphKernel(node) && IsPrimitiveCNode(GetCNodeFuncGraph(node)->output(), prim::kPrimMakeTuple);
95 });
96 return result;
97 }
98
IsSideEffectNode(const AnfNodePtr & node)99 bool IsSideEffectNode(const AnfNodePtr &node) {
100 std::vector<PrimitivePtr> side_effect_nodes = {prim::kPrimAssign};
101 return std::any_of(side_effect_nodes.begin(), side_effect_nodes.end(),
102 [&node](const PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
103 }
104
105 /* Unify the repeated output in a func_graph.
106 * %1 = call @graph_kernel(p1, p2)
107 * %2 = tuple_getitem(%1, 0)
108 * %3 = tuple_getitem(%1, 1)
109 * graph_kernel:
110 * %1 = TensorAdd(p1, p2)
111 * %2 = Reshape(%1)
112 * return make_tuple(%2, %2)
113 * -->
114 * %1 = call @graph_kernel(p1, p2)
115 * %2 = tuple_getitem(%1, 0)
116 * %3 = tuple_getitem(%1, 0) // changed the index to 0.
117 * graph_kernel:
118 * %1 = TensorAdd(p1, p2)
119 * %2 = Reshape(%1)
120 * return make_tuple(%2, %2)
121 */
122 class UnifyRepeatedOutput : public opt::Pass {
123 public:
Run(const FuncGraphPtr & func_graph)124 bool Run(const FuncGraphPtr &func_graph) override {
125 auto mng = func_graph->manager();
126 MS_EXCEPTION_IF_NULL(mng);
127 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
128 bool changed = false;
129 for (auto node : todos) {
130 if (CheckRepeatedOutput(GetCNodeFuncGraph(node))) {
131 changed = true;
132 AnfNodePtrList getitem_list;
133 (void)GetGraphKernelGetitemList(mng, node, &getitem_list, false);
134 if (getitem_list.size() != index_map_.size()) {
135 MS_LOG(EXCEPTION) << "getitem_list.size (" << getitem_list.size() << ") should be equal to index_map.size ("
136 << index_map_.size() << ").";
137 }
138 for (size_t i = 0; i < index_map_.size(); ++i) {
139 if (index_map_[i] != i && getitem_list[i] != nullptr) {
140 SetIndex(getitem_list[i], index_map_[i]);
141 }
142 }
143 }
144 }
145 return changed;
146 }
147
148 private:
CheckRepeatedOutput(const FuncGraphPtr & sub_func_graph)149 bool CheckRepeatedOutput(const FuncGraphPtr &sub_func_graph) {
150 // the output should be a MakeTuple.
151 auto maketuple = sub_func_graph->output()->cast<CNodePtr>();
152 MS_EXCEPTION_IF_NULL(maketuple);
153 AnfNodePtrList outputs(maketuple->inputs().begin() + 1, maketuple->inputs().end());
154 index_map_.resize(outputs.size());
155 bool found = false;
156 for (size_t i = 0; i < outputs.size(); ++i) {
157 index_map_[i] =
158 static_cast<size_t>(std::find(outputs.begin(), outputs.begin() + SizeToLong(i), outputs[i]) - outputs.begin());
159 if (index_map_[i] != i) {
160 found = true;
161 }
162 }
163 return found;
164 }
165 std::vector<size_t> index_map_;
166 };
167
168 /* Unify the get_item nodes that have same index.
169 * %1 = call @graph_kernel(p1, p2)
170 * %2 = tuple_getitem(%1, 0)
171 * %3 = tuple_getitem(%1, 0)
172 * %4 = tuple_getitem(%1, 1)
173 * %5 = user_x(%2)
174 * %6 = user_y(%3)
175 * %7 = user_z(%4)
176 * --->
177 * %1 = call @graph_kernel(p1, p2)
178 * %2 = tuple_getitem(%1, 0) // unify the original %2 and %3
179 * %3 = tuple_getitem(%1, 1)
180 * %4 = user_x(%2)
181 * %5 = user_y(%2)
182 * %6 = user_z(%3)
183 */
184 class UnifyRepeatedGetitem : public opt::Pass {
185 public:
Run(const FuncGraphPtr & func_graph)186 bool Run(const FuncGraphPtr &func_graph) override {
187 auto mng = func_graph->manager();
188 MS_EXCEPTION_IF_NULL(mng);
189 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
190 bool changed = false;
191 for (auto node : todos) {
192 AnfNodePtrList getitem_list;
193 changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed;
194 }
195 return changed;
196 }
197 };
198
199 /* if a graphkernel node of multi-output is directly used by other kernel,
200 * change it to use getitem-maketuple.
201 * %1 = call @graph_kernel(p1, p2) // assume it has 3 outputs.
202 * %2 = AddN(%1)
203 * --->
204 * %1 = call @graph_kernel(p1, p2)
205 * %2 = tuple_getitem(%1, 0)
206 * %3 = tuple_getitem(%1, 1)
207 * %4 = tuple_getitem(%1, 2)
208 * %5 = make_tuple(%2, %3, %4)
209 * %6 = AddN(%5)
210 */
211 class TupleNodeFormatter : public opt::Pass {
212 public:
Run(const FuncGraphPtr & func_graph)213 bool Run(const FuncGraphPtr &func_graph) override {
214 auto mng = func_graph->manager();
215 MS_EXCEPTION_IF_NULL(mng);
216 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
217 bool changed = false;
218 for (auto &node : todos) {
219 auto &users = mng->node_users()[node];
220 for (auto &user : users) {
221 if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
222 auto mt = TransToMaketuple(node);
223 (void)mng->Replace(node, mt);
224 changed = true;
225 break;
226 }
227 }
228 }
229 if (changed) {
230 (void)EliminateMaketupleGetitem(func_graph);
231 }
232 return changed;
233 }
234
TransToMaketuple(const AnfNodePtr & node) const235 AnfNodePtr TransToMaketuple(const AnfNodePtr &node) const {
236 auto fg = node->func_graph();
237 MS_EXCEPTION_IF_NULL(fg);
238 auto node_abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
239 MS_EXCEPTION_IF_NULL(node_abs);
240 auto output_num = node_abs->size();
241 AnfNodePtrList mt_inputs{NewValueNode(prim::kPrimMakeTuple)};
242 mt_inputs.reserve(output_num + 1);
243 for (size_t i = 0; i < output_num; i++) {
244 auto idx = MakeValue(SizeToLong(i));
245 AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
246 gt_inputs.back()->set_abstract(idx->ToAbstract());
247 auto > = mt_inputs.emplace_back(fg->NewCNode(gt_inputs));
248 gt->set_abstract(node_abs->elements()[i]);
249 Callback::Instance()->SetEmptyKernelInfo(gt);
250 }
251 auto mt = fg->NewCNode(mt_inputs);
252 mt->set_abstract(node_abs);
253 Callback::Instance()->SetEmptyKernelInfo(mt);
254 return mt;
255 }
256 };
257
Run(const FuncGraphPtr & func_graph)258 bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
259 auto mng = func_graph->manager();
260 if (mng == nullptr) {
261 mng = Manage(func_graph, true);
262 func_graph->set_manager(mng);
263 }
264 bool changed = false;
265 changed = std::make_shared<TupleNodeFormatter>()->Run(func_graph) || changed;
266 changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
267 changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
268 changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
269 changed = std::make_shared<EliminateHangingOutput>()->Run(func_graph) || changed;
270 return changed;
271 }
272
UpdateGetitemIndex(const AnfNodePtr & getitem,size_t offset) const273 void EliminateHangingOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) const {
274 if (offset == 0) {
275 return;
276 }
277 MS_EXCEPTION_IF_NULL(getitem);
278 auto index = GetIndex(getitem);
279 if (offset > index) {
280 MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString() << ". "
281 << offset << " vs " << index;
282 }
283 index -= offset;
284 SetIndex(getitem, index);
285 }
286
// Drops sub-graph outputs of `node` that have no getitem user (side-effect outputs are always kept),
// renumbering the surviving getitems as it goes.
// Returns nullptr if nothing was dropped. Otherwise it mutates the sub-graph output in place and returns a
// new fused CNode built over the same inputs. The new output is a single node when one output survives,
// or a new MakeTuple otherwise. Raises if every output would be dropped.
AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) const {
  auto sub_graph = GetCNodeFuncGraph(node);
  MS_EXCEPTION_IF_NULL(sub_graph);
  auto old_maketuple = sub_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_maketuple);

  AnfNodePtrList kept_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList kept_abstracts;
  size_t removed = 0;
  for (size_t i = 0; i < getitems.size(); ++i) {
    const auto &output_i = old_maketuple->input(i + 1);
    const auto &getitem_i = getitems[i];
    // Keep an output if someone reads it or it has side effects (the latter is only checked when unread).
    const bool keep = getitem_i != nullptr || IsSideEffectNode(output_i);
    if (!keep) {
      ++removed;
      continue;
    }
    kept_inputs.push_back(output_i);
    kept_abstracts.push_back(output_i->abstract());
    if (getitem_i != nullptr) {
      UpdateGetitemIndex(getitem_i, removed);
    }
  }
  if (removed == 0) {
    return nullptr;
  }
  if (kept_inputs.size() == 1) {
    MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";
  }
  // Primitive + exactly one element: output that element directly instead of a one-element MakeTuple.
  const size_t maketuple_one_input_size = 2;
  if (kept_inputs.size() == maketuple_one_input_size) {
    sub_graph->set_output(kept_inputs.back());
  } else {
    auto make_tuple = sub_graph->NewCNode(kept_inputs);
    make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(kept_abstracts));
    Callback::Instance()->SetEmptyKernelInfo(make_tuple);
    sub_graph->set_output(make_tuple);
  }

  auto old_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_cnode);
  AnfNodePtrList fused_inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end());
  return CreateNewFuseCNode(node->func_graph(), sub_graph, fused_inputs);
}
329
Run(const FuncGraphPtr & func_graph)330 bool EliminateHangingOutput::Run(const FuncGraphPtr &func_graph) {
331 auto mng = func_graph->manager();
332 MS_EXCEPTION_IF_NULL(mng);
333 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
334 bool changed = false;
335 for (auto node : todos) {
336 AnfNodePtrList getitems;
337 (void)GetGraphKernelGetitemList(mng, node, &getitems, false);
338 auto new_node = ReplaceMakeTuple(node, getitems);
339 if (new_node != nullptr) {
340 if (!IsPrimitiveCNode(GetCNodeFuncGraph(new_node)->output(), prim::kPrimMakeTuple)) {
341 // only one output, remove the getitem.
342 auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; });
343 if (i != getitems.end()) {
344 (void)mng->Replace(*i, new_node);
345 }
346 } else {
347 (void)mng->Replace(node, new_node);
348 }
349 changed = true;
350 }
351 }
352 return changed;
353 }
354 } // namespace mindspore::graphkernel
355