• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/update_state_formatter.h"
17 
18 #include <vector>
19 #include <set>
20 #include <memory>
21 #include <utility>
22 #include <algorithm>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ir/anf.h"
26 #include "ir/named.h"
27 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
28 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
29 #include "backend/common/graph_kernel/core/eliminate_redundant_output.h"
30 
31 namespace mindspore::graphkernel {
IsUpdateState(const std::pair<AnfNodePtr,int> & user)32 bool IsUpdateState(const std::pair<AnfNodePtr, int> &user) {
33   return IsPrimitiveCNode(user.first, prim::kPrimUpdateState) ||
34          (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second > 1);
35 }
36 
GetUpdateStateList(const FuncGraphPtr & func_graph)37 AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
38   auto todos = TopoSort(func_graph->get_return());
39   AnfNodePtrList result;
40   (void)std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
41     return IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimDepend);
42   });
43   return result;
44 }
45 
// Expands every tuple-typed node of `nodes` into one TupleGetItem node per tuple element.
//
// nodes:      candidate inputs of an UpdateState node; null entries are rejected with an exception.
// func_graph: the graph in which the new TupleGetItem nodes are created.
// Returns the new input list, in the original order. A node whose abstract is an AbstractTuple is
// replaced by TupleGetItem(node, 0) ... TupleGetItem(node, n - 1), so a tuple with zero elements
// contributes nothing. Any other node (including one that has no abstract at all) is kept as it is.
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
                                                            const FuncGraphPtr &func_graph) const {
  AnfNodePtrList result;
  result.reserve(nodes.size());
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    const auto &node_abstract = node->abstract();
    // Without an abstract the node cannot be identified as a tuple, so it is passed through untouched
    // instead of dereferencing a null pointer.
    if (node_abstract == nullptr || !node_abstract->isa<abstract::AbstractTuple>()) {
      result.push_back(node);
      continue;
    }

    auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_abstract);
    MS_EXCEPTION_IF_NULL(func_graph);
    // `tuple_abstract` keeps the element list alive, so binding by reference avoids copying it.
    const auto &elements = tuple_abstract->elements();
    for (size_t i = 0; i < elements.size(); ++i) {
      auto idx_val = SizeToLong(i);

      auto idx = NewValueNode(idx_val);
      MS_EXCEPTION_IF_NULL(idx);
      idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));

      auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      // The getitem takes the abstract of the tuple element it extracts.
      tuple_getitem->set_abstract(elements[i]);
      Callback::Instance()->SetEmptyKernelInfo(tuple_getitem);
      result.push_back(tuple_getitem);
    }
  }
  return result;
}
72 
Run(const FuncGraphPtr & func_graph)73 bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
74   auto todos = GetUpdateStateList(func_graph);
75   bool changed = false;
76   auto mng = func_graph->manager();
77   MS_EXCEPTION_IF_NULL(mng);
78   for (auto node : todos) {
79     auto cnode = node->cast<CNodePtr>();
80     MS_EXCEPTION_IF_NULL(cnode);
81     if (cnode->size() <= kUpdateStateRealInput) {
82       continue;
83     }
84     auto inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
85     // extend inputs of UpdateState if which have multiple outputs
86     inputs = ExtendInputsOfUpdateState(inputs, func_graph);
87     if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
88       AnfNodePtrList node_inputs = {cnode->input(kAnfPrimitiveIndex), cnode->input(kUpdateStateStateInput)};
89       (void)node_inputs.insert(node_inputs.cend(), inputs.cbegin(), inputs.cend());
90       // Create a new UpdateState
91       auto new_node = func_graph->NewCNode(node_inputs);
92       new_node->set_abstract(node->abstract());
93       (void)mng->Replace(node, new_node);
94       changed = true;
95     }
96   }
97   return changed;
98 }
99 
Run(const FuncGraphPtr & func_graph)100 bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
101   auto todos = GetUpdateStateList(func_graph);
102   bool changed = false;
103   auto mng = func_graph->manager();
104   MS_EXCEPTION_IF_NULL(mng);
105   for (const auto &node : todos) {
106     auto cnode = node->cast<CNodePtr>();
107     MS_EXCEPTION_IF_NULL(cnode);
108     if (cnode->size() <= kUpdateStateRealInput + 1) {
109       continue;
110     }
111     AnfNodePtrList mt_inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
112     AbstractBasePtrList abs_list;
113     (void)std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
114                          [](const AnfNodePtr &inp) { return inp->abstract(); });
115     (void)mt_inputs.insert(mt_inputs.cbegin(), NewValueNode(prim::kPrimMakeTuple));
116     auto mt_node = func_graph->NewCNode(mt_inputs);
117     mt_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
118     Callback::Instance()->SetEmptyKernelInfo(mt_node);
119 
120     AnfNodePtrList inputs = {cnode->input(0), cnode->input(1), mt_node};
121     auto new_node = func_graph->NewCNode(inputs);
122     new_node->set_abstract(node->abstract());
123     Callback::Instance()->SetEmptyKernelInfo(new_node);
124     (void)mng->Replace(node, new_node);
125     changed = true;
126   }
127   return changed;
128 }
129 
Run(const FuncGraphPtr & func_graph)130 bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
131   auto todos = FindGraphKernelsWithMultiOutput(func_graph);
132   auto mng = func_graph->manager();
133   MS_EXCEPTION_IF_NULL(mng);
134   bool changed = false;
135   for (const auto &node : todos) {
136     (void)GetGraphKernelGetitemList(mng, node, &getitems_, false);
137     if (getitems_.empty()) {
138       continue;
139     }
140     FindIndexesToUpdateState(mng);
141     if (indexes_.empty()) {
142       continue;
143     }
144     auto sub_func_graph = GetCNodeFuncGraph(node);
145     FilterIndexes(sub_func_graph);
146     if (indexes_.empty()) {
147       continue;
148     }
149     for (auto idx : indexes_) {
150       changed = ProcessIndex(func_graph, sub_func_graph, idx) || changed;
151     }
152   }
153   if (changed) {
154     GkUtils::UpdateFuncGraphManager(mng, func_graph);
155     auto spread_update_state = std::make_shared<SpreadUpdateState>();
156     MS_EXCEPTION_IF_NULL(spread_update_state);
157     (void)spread_update_state->Run(func_graph);
158     auto elim_hanging_output = std::make_shared<EliminateHangingOutput>();
159     MS_EXCEPTION_IF_NULL(elim_hanging_output);
160     (void)elim_hanging_output->Run(func_graph);
161   }
162   return changed;
163 }
164 
FindIndexesToUpdateState(const FuncGraphManagerPtr & mng)165 void ExtendOutputForUpdateState::FindIndexesToUpdateState(const FuncGraphManagerPtr &mng) {
166   indexes_.clear();
167   external_user_type_.clear();
168   external_user_type_.resize(getitems_.size(), ExternalUserType::kNormalOp);
169   for (size_t i = 0; i < getitems_.size(); ++i) {
170     const AnfNodePtr &getitem = getitems_[i];
171     if (getitem == nullptr) {
172       continue;
173     }
174 
175     const auto &getitem_user = mng->node_users()[getitem];
176     if (std::all_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
177       external_user_type_[i] = ExternalUserType::kUpdateState;
178       indexes_.push_back(i);
179     } else if (std::any_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
180       external_user_type_[i] = ExternalUserType::kMix;
181       indexes_.push_back(i);
182     }
183   }
184 }
185 
FilterIndexes(const FuncGraphPtr & func_graph)186 void ExtendOutputForUpdateState::FilterIndexes(const FuncGraphPtr &func_graph) {
187   auto output_node = func_graph->output()->cast<CNodePtr>();
188   // do not process the side-effect nodes.
189   (void)indexes_.erase(std::remove_if(indexes_.begin(), indexes_.end(),
190                                       [&output_node](size_t i) { return IsSideEffectNode(output_node->input(i + 1)); }),
191                        indexes_.cend());
192 }
193 
FindAllOutputs(const FuncGraphPtr & func_graph,size_t index)194 std::vector<size_t> ExtendOutputForUpdateState::FindAllOutputs(const FuncGraphPtr &func_graph, size_t index) {
195   auto output_node = func_graph->output()->cast<CNodePtr>();
196   auto index_node = output_node->input(index);
197   std::vector<size_t> group;
198 
199   // if the `out_node` is a user (direct or indirect) of the `index_node`, returns true
200   auto DependsOnIndexNode = [&index_node](const AnfNodePtr &out_node) -> bool {
201     bool result = false;
202     auto IncludeFunc = [&result, &index_node](const AnfNodePtr &node) {
203       if (node == index_node) {
204         result = true;
205         return EXCLUDE;
206       }
207       return result ? EXCLUDE : FOLLOW;
208     };
209     static_cast<void>(DeepLinkedGraphSearch(out_node, IncludeFunc));
210     return result;
211   };
212 
213   for (size_t i = 1; i < output_node->size(); i++) {
214     auto out = output_node->input(i);
215     // only process the nodes that depend on index_node.
216     if (!DependsOnIndexNode(out)) {
217       continue;
218     }
219 
220     // 1. always extend to the side-effect nodes
221     // 2. if the external users are only UpdateState, the related output will be eliminated,
222     //    so only the getitem with realkernel user can be extended to.
223     if (IsSideEffectNode(out) ||
224         (getitems_[i - 1] != nullptr && external_user_type_[i - 1] != ExternalUserType::kUpdateState)) {
225       group.push_back(i - 1);
226     }
227   }
228   return group;
229 }
230 
ProcessIndex(const FuncGraphPtr & func_graph,const FuncGraphPtr & sub_func_graph,size_t index)231 bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph,
232                                               size_t index) {
233   auto group = FindAllOutputs(sub_func_graph, index + 1);
234   AnfNodePtr new_node = nullptr;
235   if (group.size() == 1 && group[0] == index) {
236     return false;
237   }
238   if (group.empty()) {
239     // the output is not side-effect node, but it doesn't have realkernel user.
240     // replace the getitem with a value node that is unrelated to the original node,
241     // here a value node with value None is used,
242     // this value node will be removed in later pass(MergeOutputForUpdateState).
243     MS_LOG(INFO) << "The " << getitems_[index]->fullname_with_scope() << " only has UpdateState user.";
244     new_node = NewValueNode(kNone)->cast<AnfNodePtr>();
245     new_node->set_abstract(kNone->ToAbstract());
246   } else {
247     // Create MakeTuple, even though the group size is 1, the following pass will spread the MakeTuple,
248     // so it's unnecessary to set abstract for it.
249     AnfNodePtrList mt_input = {NewValueNode(prim::kPrimMakeTuple)};
250     (void)std::transform(group.begin(), group.end(), std::back_inserter(mt_input),
251                          [this](size_t idx) { return getitems_[idx]; });
252     new_node = func_graph->NewCNode(mt_input)->cast<AnfNodePtr>();
253   }
254   auto mng = func_graph->manager();
255   MS_EXCEPTION_IF_NULL(mng);
256   for (auto user : mng->node_users()[getitems_[index]]) {
257     if (IsUpdateState(user)) {
258       user.first->cast<CNodePtr>()->set_input(IntToSize(user.second), new_node);
259     }
260   }
261   return true;
262 }
263 
Run(const FuncGraphPtr & func_graph)264 bool MergeOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
265   constexpr size_t min_input_num = 2;
266   auto todos = GetUpdateStateList(func_graph);
267   bool changed = false;
268   for (auto node : todos) {
269     auto cnode = node->cast<CNodePtr>();
270     MS_EXCEPTION_IF_NULL(cnode);
271     AnfNodePtrList inputs = {cnode->input(0), cnode->input(1)};
272     std::set<AnfNodePtr> node_set;
273     for (size_t i = min_input_num; i < cnode->size(); ++i) {
274       auto input = cnode->input(i);
275       if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
276         // only keep one GetItem for that link to the same node.
277         auto gt_input = input->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
278         if (node_set.insert(gt_input).second) {
279           inputs.push_back(input);
280         }
281       } else {
282         if (input->isa<ValueNode>()) {
283           auto value_node = input->cast<ValueNodePtr>();
284           auto value = value_node->value();
285           // filter the None valuenode generated by "ExtendOutputForUpdateState"
286           if (value->isa<None>()) {
287             continue;
288           }
289         }
290         if (node_set.insert(input).second) {
291           inputs.push_back(input);
292         }
293       }
294     }
295 
296     if (inputs.size() == min_input_num) {
297       inputs.push_back(inputs.back());
298       cnode->set_inputs(inputs);
299       changed = true;
300     } else if (inputs.size() < cnode->size()) {
301       cnode->set_inputs(inputs);
302       changed = true;
303     }
304   }
305   if (changed) {
306     auto mng = func_graph->manager();
307     MS_EXCEPTION_IF_NULL(mng);
308     mng->RemoveRoots();
309     mng->KeepRoots({func_graph});
310   }
311   return changed;
312 }
313 }  // namespace mindspore::graphkernel
314