1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/core/update_state_formatter.h"
17
18 #include <vector>
19 #include <set>
20 #include <memory>
21 #include <utility>
22 #include <algorithm>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ir/anf.h"
26 #include "ir/named.h"
27 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
28 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
29 #include "backend/common/graph_kernel/core/eliminate_redundant_output.h"
30
31 namespace mindspore::graphkernel {
IsUpdateState(const std::pair<AnfNodePtr,int> & user)32 bool IsUpdateState(const std::pair<AnfNodePtr, int> &user) {
33 return IsPrimitiveCNode(user.first, prim::kPrimUpdateState) ||
34 (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second > 1);
35 }
36
GetUpdateStateList(const FuncGraphPtr & func_graph)37 AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
38 auto todos = TopoSort(func_graph->get_return());
39 AnfNodePtrList result;
40 (void)std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
41 return IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimDepend);
42 });
43 return result;
44 }
45
// Expands every tuple-typed node in `nodes` into one TupleGetItem per tuple element.
//
// For a node whose abstract is an AbstractTuple, a `TupleGetItem(node, i)` CNode is created in
// `func_graph` for each element i; it receives the element's abstract and an empty kernel info.
// Any other node is forwarded unchanged. The relative order of `nodes` is preserved.
//
// A node without an abstract cannot be identified as a tuple, so it is forwarded unchanged
// instead of being dereferenced. A null node raises an exception.
//
// Returns the expanded node list.
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
                                                            const FuncGraphPtr &func_graph) const {
  AnfNodePtrList result;
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    auto node_abs = node->abstract();
    if (node_abs == nullptr || !node_abs->isa<abstract::AbstractTuple>()) {
      result.push_back(node);
      continue;
    }

    MS_EXCEPTION_IF_NULL(func_graph);
    auto tuple_abs = node_abs->cast<abstract::AbstractTuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_abs);
    // `tuple_abs` keeps the tuple abstract alive while its element list is referenced.
    const auto &elements = tuple_abs->elements();
    for (size_t i = 0; i < elements.size(); ++i) {
      auto idx_val = SizeToLong(i);

      auto idx = NewValueNode(idx_val);
      MS_EXCEPTION_IF_NULL(idx);
      idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));

      auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      tuple_getitem->set_abstract(elements[i]);
      Callback::Instance()->SetEmptyKernelInfo(tuple_getitem);
      result.push_back(tuple_getitem);
    }
  }
  return result;
}
72
Run(const FuncGraphPtr & func_graph)73 bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
74 auto todos = GetUpdateStateList(func_graph);
75 bool changed = false;
76 auto mng = func_graph->manager();
77 MS_EXCEPTION_IF_NULL(mng);
78 for (auto node : todos) {
79 auto cnode = node->cast<CNodePtr>();
80 MS_EXCEPTION_IF_NULL(cnode);
81 if (cnode->size() <= kUpdateStateRealInput) {
82 continue;
83 }
84 auto inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
85 // extend inputs of UpdateState if which have multiple outputs
86 inputs = ExtendInputsOfUpdateState(inputs, func_graph);
87 if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
88 AnfNodePtrList node_inputs = {cnode->input(kAnfPrimitiveIndex), cnode->input(kUpdateStateStateInput)};
89 (void)node_inputs.insert(node_inputs.cend(), inputs.cbegin(), inputs.cend());
90 // Create a new UpdateState
91 auto new_node = func_graph->NewCNode(node_inputs);
92 new_node->set_abstract(node->abstract());
93 (void)mng->Replace(node, new_node);
94 changed = true;
95 }
96 }
97 return changed;
98 }
99
Run(const FuncGraphPtr & func_graph)100 bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
101 auto todos = GetUpdateStateList(func_graph);
102 bool changed = false;
103 auto mng = func_graph->manager();
104 MS_EXCEPTION_IF_NULL(mng);
105 for (const auto &node : todos) {
106 auto cnode = node->cast<CNodePtr>();
107 MS_EXCEPTION_IF_NULL(cnode);
108 if (cnode->size() <= kUpdateStateRealInput + 1) {
109 continue;
110 }
111 AnfNodePtrList mt_inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
112 AbstractBasePtrList abs_list;
113 (void)std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
114 [](const AnfNodePtr &inp) { return inp->abstract(); });
115 (void)mt_inputs.insert(mt_inputs.cbegin(), NewValueNode(prim::kPrimMakeTuple));
116 auto mt_node = func_graph->NewCNode(mt_inputs);
117 mt_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
118 Callback::Instance()->SetEmptyKernelInfo(mt_node);
119
120 AnfNodePtrList inputs = {cnode->input(0), cnode->input(1), mt_node};
121 auto new_node = func_graph->NewCNode(inputs);
122 new_node->set_abstract(node->abstract());
123 Callback::Instance()->SetEmptyKernelInfo(new_node);
124 (void)mng->Replace(node, new_node);
125 changed = true;
126 }
127 return changed;
128 }
129
Run(const FuncGraphPtr & func_graph)130 bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
131 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
132 auto mng = func_graph->manager();
133 MS_EXCEPTION_IF_NULL(mng);
134 bool changed = false;
135 for (const auto &node : todos) {
136 (void)GetGraphKernelGetitemList(mng, node, &getitems_, false);
137 if (getitems_.empty()) {
138 continue;
139 }
140 FindIndexesToUpdateState(mng);
141 if (indexes_.empty()) {
142 continue;
143 }
144 auto sub_func_graph = GetCNodeFuncGraph(node);
145 FilterIndexes(sub_func_graph);
146 if (indexes_.empty()) {
147 continue;
148 }
149 for (auto idx : indexes_) {
150 changed = ProcessIndex(func_graph, sub_func_graph, idx) || changed;
151 }
152 }
153 if (changed) {
154 GkUtils::UpdateFuncGraphManager(mng, func_graph);
155 auto spread_update_state = std::make_shared<SpreadUpdateState>();
156 MS_EXCEPTION_IF_NULL(spread_update_state);
157 (void)spread_update_state->Run(func_graph);
158 auto elim_hanging_output = std::make_shared<EliminateHangingOutput>();
159 MS_EXCEPTION_IF_NULL(elim_hanging_output);
160 (void)elim_hanging_output->Run(func_graph);
161 }
162 return changed;
163 }
164
FindIndexesToUpdateState(const FuncGraphManagerPtr & mng)165 void ExtendOutputForUpdateState::FindIndexesToUpdateState(const FuncGraphManagerPtr &mng) {
166 indexes_.clear();
167 external_user_type_.clear();
168 external_user_type_.resize(getitems_.size(), ExternalUserType::kNormalOp);
169 for (size_t i = 0; i < getitems_.size(); ++i) {
170 const AnfNodePtr &getitem = getitems_[i];
171 if (getitem == nullptr) {
172 continue;
173 }
174
175 const auto &getitem_user = mng->node_users()[getitem];
176 if (std::all_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
177 external_user_type_[i] = ExternalUserType::kUpdateState;
178 indexes_.push_back(i);
179 } else if (std::any_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
180 external_user_type_[i] = ExternalUserType::kMix;
181 indexes_.push_back(i);
182 }
183 }
184 }
185
FilterIndexes(const FuncGraphPtr & func_graph)186 void ExtendOutputForUpdateState::FilterIndexes(const FuncGraphPtr &func_graph) {
187 auto output_node = func_graph->output()->cast<CNodePtr>();
188 // do not process the side-effect nodes.
189 (void)indexes_.erase(std::remove_if(indexes_.begin(), indexes_.end(),
190 [&output_node](size_t i) { return IsSideEffectNode(output_node->input(i + 1)); }),
191 indexes_.cend());
192 }
193
FindAllOutputs(const FuncGraphPtr & func_graph,size_t index)194 std::vector<size_t> ExtendOutputForUpdateState::FindAllOutputs(const FuncGraphPtr &func_graph, size_t index) {
195 auto output_node = func_graph->output()->cast<CNodePtr>();
196 auto index_node = output_node->input(index);
197 std::vector<size_t> group;
198
199 // if the `out_node` is a user (direct or indirect) of the `index_node`, returns true
200 auto DependsOnIndexNode = [&index_node](const AnfNodePtr &out_node) -> bool {
201 bool result = false;
202 auto IncludeFunc = [&result, &index_node](const AnfNodePtr &node) {
203 if (node == index_node) {
204 result = true;
205 return EXCLUDE;
206 }
207 return result ? EXCLUDE : FOLLOW;
208 };
209 static_cast<void>(DeepLinkedGraphSearch(out_node, IncludeFunc));
210 return result;
211 };
212
213 for (size_t i = 1; i < output_node->size(); i++) {
214 auto out = output_node->input(i);
215 // only process the nodes that depend on index_node.
216 if (!DependsOnIndexNode(out)) {
217 continue;
218 }
219
220 // 1. always extend to the side-effect nodes
221 // 2. if the external users are only UpdateState, the related output will be eliminated,
222 // so only the getitem with realkernel user can be extended to.
223 if (IsSideEffectNode(out) ||
224 (getitems_[i - 1] != nullptr && external_user_type_[i - 1] != ExternalUserType::kUpdateState)) {
225 group.push_back(i - 1);
226 }
227 }
228 return group;
229 }
230
ProcessIndex(const FuncGraphPtr & func_graph,const FuncGraphPtr & sub_func_graph,size_t index)231 bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph,
232 size_t index) {
233 auto group = FindAllOutputs(sub_func_graph, index + 1);
234 AnfNodePtr new_node = nullptr;
235 if (group.size() == 1 && group[0] == index) {
236 return false;
237 }
238 if (group.empty()) {
239 // the output is not side-effect node, but it doesn't have realkernel user.
240 // replace the getitem with a value node that is unrelated to the original node,
241 // here a value node with value None is used,
242 // this value node will be removed in later pass(MergeOutputForUpdateState).
243 MS_LOG(INFO) << "The " << getitems_[index]->fullname_with_scope() << " only has UpdateState user.";
244 new_node = NewValueNode(kNone)->cast<AnfNodePtr>();
245 new_node->set_abstract(kNone->ToAbstract());
246 } else {
247 // Create MakeTuple, even though the group size is 1, the following pass will spread the MakeTuple,
248 // so it's unnecessary to set abstract for it.
249 AnfNodePtrList mt_input = {NewValueNode(prim::kPrimMakeTuple)};
250 (void)std::transform(group.begin(), group.end(), std::back_inserter(mt_input),
251 [this](size_t idx) { return getitems_[idx]; });
252 new_node = func_graph->NewCNode(mt_input)->cast<AnfNodePtr>();
253 }
254 auto mng = func_graph->manager();
255 MS_EXCEPTION_IF_NULL(mng);
256 for (auto user : mng->node_users()[getitems_[index]]) {
257 if (IsUpdateState(user)) {
258 user.first->cast<CNodePtr>()->set_input(IntToSize(user.second), new_node);
259 }
260 }
261 return true;
262 }
263
Run(const FuncGraphPtr & func_graph)264 bool MergeOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
265 constexpr size_t min_input_num = 2;
266 auto todos = GetUpdateStateList(func_graph);
267 bool changed = false;
268 for (auto node : todos) {
269 auto cnode = node->cast<CNodePtr>();
270 MS_EXCEPTION_IF_NULL(cnode);
271 AnfNodePtrList inputs = {cnode->input(0), cnode->input(1)};
272 std::set<AnfNodePtr> node_set;
273 for (size_t i = min_input_num; i < cnode->size(); ++i) {
274 auto input = cnode->input(i);
275 if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
276 // only keep one GetItem for that link to the same node.
277 auto gt_input = input->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
278 if (node_set.insert(gt_input).second) {
279 inputs.push_back(input);
280 }
281 } else {
282 if (input->isa<ValueNode>()) {
283 auto value_node = input->cast<ValueNodePtr>();
284 auto value = value_node->value();
285 // filter the None valuenode generated by "ExtendOutputForUpdateState"
286 if (value->isa<None>()) {
287 continue;
288 }
289 }
290 if (node_set.insert(input).second) {
291 inputs.push_back(input);
292 }
293 }
294 }
295
296 if (inputs.size() == min_input_num) {
297 inputs.push_back(inputs.back());
298 cnode->set_inputs(inputs);
299 changed = true;
300 } else if (inputs.size() < cnode->size()) {
301 cnode->set_inputs(inputs);
302 changed = true;
303 }
304 }
305 if (changed) {
306 auto mng = func_graph->manager();
307 MS_EXCEPTION_IF_NULL(mng);
308 mng->RemoveRoots();
309 mng->KeepRoots({func_graph});
310 }
311 return changed;
312 }
313 } // namespace mindspore::graphkernel
314