1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
17
18 #include <vector>
19 #include <set>
20 #include <memory>
21 #include <utility>
22 #include <algorithm>
23 #include "backend/session/anf_runtime_algorithm.h"
24 #include "backend/kernel_compiler/common_utils.h"
25 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
26 #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
27
28 namespace mindspore {
29 namespace opt {
GetUpdateStateList(const FuncGraphPtr & func_graph)30 AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
31 auto todos = TopoSort(func_graph->get_return());
32 AnfNodePtrList result;
33 std::copy_if(todos.begin(), todos.end(), std::back_inserter(result),
34 [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimUpdateState); });
35 return result;
36 }
37
// Flattens nodes[begin_index..] into a single list. Every MakeTuple CNode is replaced, recursively,
// by its element inputs; any other node is kept as is. Relative order is preserved.
// If begin_index >= nodes.size(), the result is empty.
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList flat;
  for (size_t pos = begin_index; pos < nodes.size(); ++pos) {
    const auto &item = nodes[pos];
    if (!IsPrimitiveCNode(item, prim::kPrimMakeTuple)) {
      flat.push_back(item);
      continue;
    }
    // Input 0 of a MakeTuple is the primitive, so its elements start at index 1.
    // Recurse so that tuples nested inside tuples are flattened as well.
    auto tuple_node = item->cast<CNodePtr>();
    const auto inner = SpreadTuples(tuple_node->inputs(), 1);
    flat.insert(flat.end(), inner.begin(), inner.end());
  }
  return flat;
}
52
// Expands multi-output inputs: every node whose abstract is an AbstractTuple is replaced by one
// TupleGetItem(node, k) per tuple element. The new getitems are created in `func_graph`, carry the
// element's abstract, and get an empty KernelInfo. Nodes with any other abstract are passed through.
// NOTE(review): dereferences node->abstract() unchecked; assumes every input has an abstract.
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
                                                            const FuncGraphPtr &func_graph) {
  AnfNodePtrList extended;
  for (const auto &node : nodes) {
    if (!node->abstract()->isa<abstract::AbstractTuple>()) {
      extended.push_back(node);
      continue;
    }
    auto elements = node->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
    const size_t element_count = elements.size();
    for (size_t k = 0; k < element_count; ++k) {
      auto index_value = SizeToLong(k);
      auto index_node = NewValueNode(index_value);
      MS_EXCEPTION_IF_NULL(index_node);
      index_node->set_abstract(std::make_shared<abstract::AbstractScalar>(index_value));

      auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, index_node});
      MS_EXCEPTION_IF_NULL(getitem);
      getitem->set_abstract(elements[k]);
      getitem->set_kernel_info(std::make_shared<device::KernelInfo>());
      extended.push_back(getitem);
    }
  }
  return extended;
}
79
Run(const FuncGraphPtr & func_graph)80 bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
81 auto todos = GetUpdateStateList(func_graph);
82 bool changed = false;
83 auto mng = func_graph->manager();
84 MS_EXCEPTION_IF_NULL(mng);
85 for (auto node : todos) {
86 auto cnode = node->cast<CNodePtr>();
87 MS_EXCEPTION_IF_NULL(cnode);
88 if (cnode->size() <= kUpdateStateRealInput) continue;
89 auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
90 // extend inputs of UpdateState if which have multiple outputs
91 inputs = ExtendInputsOfUpdateState(inputs, func_graph);
92 if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
93 AnfNodePtrList node_inputs = {cnode->input(kAnfPrimitiveIndex), cnode->input(kUpdateStateStateInput)};
94 node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
95 // Create a new UpdateState
96 auto new_node = func_graph->NewCNode(node_inputs);
97 new_node->set_abstract(node->abstract());
98 (void)mng->Replace(node, new_node);
99 changed = true;
100 }
101 }
102 return changed;
103 }
104
Run(const FuncGraphPtr & func_graph)105 bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
106 auto todos = GetUpdateStateList(func_graph);
107 bool changed = false;
108 auto mng = func_graph->manager();
109 MS_EXCEPTION_IF_NULL(mng);
110 for (const auto &node : todos) {
111 auto cnode = node->cast<CNodePtr>();
112 MS_EXCEPTION_IF_NULL(cnode);
113 if (cnode->size() <= kUpdateStateRealInput + 1) continue;
114 AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
115 AbstractBasePtrList abs_list;
116 std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
117 [](const AnfNodePtr &inp) { return inp->abstract(); });
118 mt_inputs.insert(mt_inputs.begin(), NewValueNode(prim::kPrimMakeTuple));
119 auto mt_node = func_graph->NewCNode(mt_inputs);
120 mt_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
121 mt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
122
123 AnfNodePtrList inputs = {cnode->input(0), cnode->input(1), mt_node};
124 auto new_node = func_graph->NewCNode(inputs);
125 new_node->set_abstract(node->abstract());
126 new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
127 (void)mng->Replace(node, new_node);
128 changed = true;
129 }
130 return changed;
131 }
132
Run(const FuncGraphPtr & func_graph)133 bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
134 auto todos = FindGraphKernelsWithMultiOutput(func_graph);
135 auto mng = func_graph->manager();
136 MS_EXCEPTION_IF_NULL(mng);
137 bool changed = false;
138 for (const auto &node : todos) {
139 GetGraphKernelGetitemList(mng, node, &getitems_, false);
140 if (getitems_.empty()) continue;
141 FindIndexesToUpdateState(mng);
142 if (indexes_.empty()) continue;
143 auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
144 FilterIndexes(sub_func_graph);
145 if (indexes_.empty()) continue;
146 for (auto idx : indexes_) {
147 changed = ProcessIndex(func_graph, sub_func_graph, idx) || changed;
148 }
149 }
150 if (changed) {
151 std::make_shared<SpreadUpdateState>()->Run(func_graph);
152 std::make_shared<EliminateHangingOutput>()->Run(func_graph);
153 }
154 return changed;
155 }
156
FindIndexesToUpdateState(const FuncGraphManagerPtr & mng)157 void ExtendOutputForUpdateState::FindIndexesToUpdateState(const FuncGraphManagerPtr &mng) {
158 indexes_.clear();
159 external_user_type_.clear();
160 external_user_type_.resize(getitems_.size(), ExternalUserType::kNormalOp);
161 for (size_t i = 0; i < getitems_.size(); ++i) {
162 const AnfNodePtr &getitem = getitems_[i];
163 if (getitem == nullptr) continue;
164
165 const auto &getitem_user = mng->node_users()[getitem];
166 auto IsUpdateState = [](const std::pair<AnfNodePtr, int> &user) {
167 return IsPrimitiveCNode(user.first, prim::kPrimUpdateState);
168 };
169 if (std::all_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
170 external_user_type_[i] = ExternalUserType::kUpdateState;
171 indexes_.push_back(i);
172 } else if (std::any_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
173 external_user_type_[i] = ExternalUserType::kMix;
174 indexes_.push_back(i);
175 }
176 }
177 }
178
FilterIndexes(const FuncGraphPtr & func_graph)179 void ExtendOutputForUpdateState::FilterIndexes(const FuncGraphPtr &func_graph) {
180 auto output_node = func_graph->output()->cast<CNodePtr>();
181 // do not process the side-effect nodes.
182 indexes_.erase(std::remove_if(indexes_.begin(), indexes_.end(),
183 [&output_node](size_t i) { return IsSideEffectNode(output_node->input(i + 1)); }),
184 indexes_.end());
185 }
186
FindAllOutputs(const FuncGraphPtr & func_graph,size_t index)187 std::vector<size_t> ExtendOutputForUpdateState::FindAllOutputs(const FuncGraphPtr &func_graph, size_t index) {
188 auto output_node = func_graph->output()->cast<CNodePtr>();
189 auto index_node = output_node->input(index);
190 std::vector<size_t> group;
191
192 // if the `out_node` is a user (direct or indirect) of the `index_node`, returns true
193 auto DependsOnIndexNode = [&index_node](const AnfNodePtr &out_node) -> bool {
194 bool result = false;
195 auto IncludeFunc = [&result, &index_node](const AnfNodePtr &node) {
196 if (node == index_node) {
197 result = true;
198 return EXCLUDE;
199 }
200 return result ? EXCLUDE : FOLLOW;
201 };
202 static_cast<void>(DeepLinkedGraphSearch(out_node, IncludeFunc));
203 return result;
204 };
205
206 for (size_t i = 1; i < output_node->size(); i++) {
207 auto out = output_node->input(i);
208 // only process the nodes that depend on index_node.
209 if (!DependsOnIndexNode(out)) continue;
210
211 // 1. always extend to the side-effect nodes
212 // 2. if the external users are only UpdateState, the related output will be eliminated,
213 // so only the getitem with realkernel user can be extended to.
214 if (IsSideEffectNode(out) ||
215 (getitems_[i - 1] != nullptr && external_user_type_[i - 1] != ExternalUserType::kUpdateState)) {
216 group.push_back(i - 1);
217 }
218 }
219 return group;
220 }
221
ProcessIndex(const FuncGraphPtr & func_graph,const FuncGraphPtr & sub_func_graph,size_t index)222 bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph,
223 size_t index) {
224 auto group = FindAllOutputs(sub_func_graph, index + 1);
225 AnfNodePtr new_node = nullptr;
226 if (group.size() == 1 && group[0] == index) {
227 return false;
228 }
229 if (group.empty()) {
230 // the output is not side-effect node, but it hasn't realkernel user.
231 // replace the getitem with a value node that is unrelated to the original node.
232 // and this value node will be removed at the later pass.
233 MS_LOG(INFO) << "The " << getitems_[index]->fullname_with_scope() << " only has UpdateState user.";
234 new_node = NewValueNode(kUMonad)->cast<AnfNodePtr>();
235 new_node->set_abstract(kUMonad->ToAbstract());
236 } else {
237 // Create MakeTuple, even though the group size is 1, the following pass will spread the MakeTuple,
238 // so it's unnecessary to set abstract for it.
239 AnfNodePtrList mt_input = {NewValueNode(prim::kPrimMakeTuple)};
240 std::transform(group.begin(), group.end(), std::back_inserter(mt_input),
241 [this](size_t idx) { return getitems_[idx]; });
242 new_node = func_graph->NewCNode(mt_input)->cast<AnfNodePtr>();
243 }
244 auto mng = func_graph->manager();
245 MS_EXCEPTION_IF_NULL(mng);
246 for (auto user : mng->node_users()[getitems_[index]]) {
247 if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
248 user.first->cast<CNodePtr>()->set_input(IntToSize(user.second), new_node);
249 }
250 }
251 return true;
252 }
253
Run(const FuncGraphPtr & func_graph)254 bool MergeOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
255 auto todos = GetUpdateStateList(func_graph);
256 bool changed = false;
257 for (auto node : todos) {
258 auto cnode = node->cast<CNodePtr>();
259 MS_EXCEPTION_IF_NULL(cnode);
260 AnfNodePtrList inputs = {cnode->input(0), cnode->input(1)};
261 std::set<AnfNodePtr> node_set;
262 for (size_t i = 2; i < cnode->size(); ++i) {
263 auto input = cnode->input(i);
264 if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
265 // only keep one GetItem for that link to the same node.
266 auto gt_input = input->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
267 if (node_set.insert(gt_input).second) {
268 inputs.push_back(input);
269 }
270 } else if (!HasAbstractUMonad(input)) { /* filter the UMonad that was added in "ExtendOutputForUpdateState" */
271 if (node_set.insert(input).second) {
272 inputs.push_back(input);
273 }
274 }
275 }
276
277 if (inputs.size() < cnode->size()) {
278 cnode->set_inputs(inputs);
279 changed = true;
280 }
281 }
282 if (changed) {
283 auto mng = func_graph->manager();
284 MS_EXCEPTION_IF_NULL(mng);
285 mng->RemoveRoots();
286 mng->KeepRoots({func_graph});
287 }
288 return changed;
289 }
290 } // namespace opt
291 } // namespace mindspore
292