• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
17 
18 #include <vector>
19 #include <set>
20 #include <memory>
21 #include <utility>
22 #include <algorithm>
23 #include "backend/session/anf_runtime_algorithm.h"
24 #include "backend/kernel_compiler/common_utils.h"
25 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
26 #include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
27 
28 namespace mindspore {
29 namespace opt {
GetUpdateStateList(const FuncGraphPtr & func_graph)30 AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
31   auto todos = TopoSort(func_graph->get_return());
32   AnfNodePtrList result;
33   std::copy_if(todos.begin(), todos.end(), std::back_inserter(result),
34                [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimUpdateState); });
35   return result;
36 }
37 
// Flattens nodes[begin_index, end) into one list.
// Every MakeTuple CNode is replaced by its element inputs (input 1 onward); nested
// MakeTuples are flattened recursively. All other nodes are copied unchanged.
// The relative order of the surviving nodes is preserved.
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList flat;
  for (size_t pos = begin_index; pos < nodes.size(); ++pos) {
    const auto &item = nodes[pos];
    if (!IsPrimitiveCNode(item, prim::kPrimMakeTuple)) {
      flat.push_back(item);
      continue;
    }
    // Skip input 0 (the MakeTuple primitive) and flatten the elements.
    auto make_tuple = item->cast<CNodePtr>();
    auto elements = SpreadTuples(make_tuple->inputs(), 1);
    flat.insert(flat.end(), elements.begin(), elements.end());
  }
  return flat;
}
52 
// Replaces every node whose abstract is an AbstractTuple by one TupleGetItem node per tuple
// element. Each getitem is created in `func_graph` with:
//   - a scalar index ValueNode,
//   - the element's abstract,
//   - a fresh KernelInfo.
// Nodes with a non-tuple abstract are passed through unchanged. Output order follows `nodes`,
// and within a tuple, the element order.
//
// @param nodes      candidate UpdateState inputs; each must be non-null with a non-null abstract.
// @param func_graph graph that owns the newly created TupleGetItem nodes.
// @return the expanded node list.
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
                                                            const FuncGraphPtr &func_graph) {
  AnfNodePtrList result;
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    auto node_abs = node->abstract();
    // Fail with a clear exception instead of dereferencing a null abstract.
    MS_EXCEPTION_IF_NULL(node_abs);
    if (!node_abs->isa<abstract::AbstractTuple>()) {
      result.push_back(node);
      continue;
    }
    // Bind by reference: the element list is only read, so copying it is wasted work.
    // `node_abs` keeps the tuple abstract alive for the whole loop below.
    const auto &elements = node_abs->cast<abstract::AbstractTuplePtr>()->elements();
    for (size_t i = 0; i < elements.size(); i++) {
      auto idx_val = SizeToLong(i);

      auto idx = NewValueNode(idx_val);
      MS_EXCEPTION_IF_NULL(idx);
      idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));

      auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      tuple_getitem->set_abstract(elements[i]);
      tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>());
      result.push_back(tuple_getitem);
    }
  }
  return result;
}
79 
Run(const FuncGraphPtr & func_graph)80 bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
81   auto todos = GetUpdateStateList(func_graph);
82   bool changed = false;
83   auto mng = func_graph->manager();
84   MS_EXCEPTION_IF_NULL(mng);
85   for (auto node : todos) {
86     auto cnode = node->cast<CNodePtr>();
87     MS_EXCEPTION_IF_NULL(cnode);
88     if (cnode->size() <= kUpdateStateRealInput) continue;
89     auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
90     // extend inputs of UpdateState if which have multiple outputs
91     inputs = ExtendInputsOfUpdateState(inputs, func_graph);
92     if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
93       AnfNodePtrList node_inputs = {cnode->input(kAnfPrimitiveIndex), cnode->input(kUpdateStateStateInput)};
94       node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
95       // Create a new UpdateState
96       auto new_node = func_graph->NewCNode(node_inputs);
97       new_node->set_abstract(node->abstract());
98       (void)mng->Replace(node, new_node);
99       changed = true;
100     }
101   }
102   return changed;
103 }
104 
Run(const FuncGraphPtr & func_graph)105 bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
106   auto todos = GetUpdateStateList(func_graph);
107   bool changed = false;
108   auto mng = func_graph->manager();
109   MS_EXCEPTION_IF_NULL(mng);
110   for (const auto &node : todos) {
111     auto cnode = node->cast<CNodePtr>();
112     MS_EXCEPTION_IF_NULL(cnode);
113     if (cnode->size() <= kUpdateStateRealInput + 1) continue;
114     AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
115     AbstractBasePtrList abs_list;
116     std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
117                    [](const AnfNodePtr &inp) { return inp->abstract(); });
118     mt_inputs.insert(mt_inputs.begin(), NewValueNode(prim::kPrimMakeTuple));
119     auto mt_node = func_graph->NewCNode(mt_inputs);
120     mt_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
121     mt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
122 
123     AnfNodePtrList inputs = {cnode->input(0), cnode->input(1), mt_node};
124     auto new_node = func_graph->NewCNode(inputs);
125     new_node->set_abstract(node->abstract());
126     new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
127     (void)mng->Replace(node, new_node);
128     changed = true;
129   }
130   return changed;
131 }
132 
Run(const FuncGraphPtr & func_graph)133 bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
134   auto todos = FindGraphKernelsWithMultiOutput(func_graph);
135   auto mng = func_graph->manager();
136   MS_EXCEPTION_IF_NULL(mng);
137   bool changed = false;
138   for (const auto &node : todos) {
139     GetGraphKernelGetitemList(mng, node, &getitems_, false);
140     if (getitems_.empty()) continue;
141     FindIndexesToUpdateState(mng);
142     if (indexes_.empty()) continue;
143     auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
144     FilterIndexes(sub_func_graph);
145     if (indexes_.empty()) continue;
146     for (auto idx : indexes_) {
147       changed = ProcessIndex(func_graph, sub_func_graph, idx) || changed;
148     }
149   }
150   if (changed) {
151     std::make_shared<SpreadUpdateState>()->Run(func_graph);
152     std::make_shared<EliminateHangingOutput>()->Run(func_graph);
153   }
154   return changed;
155 }
156 
FindIndexesToUpdateState(const FuncGraphManagerPtr & mng)157 void ExtendOutputForUpdateState::FindIndexesToUpdateState(const FuncGraphManagerPtr &mng) {
158   indexes_.clear();
159   external_user_type_.clear();
160   external_user_type_.resize(getitems_.size(), ExternalUserType::kNormalOp);
161   for (size_t i = 0; i < getitems_.size(); ++i) {
162     const AnfNodePtr &getitem = getitems_[i];
163     if (getitem == nullptr) continue;
164 
165     const auto &getitem_user = mng->node_users()[getitem];
166     auto IsUpdateState = [](const std::pair<AnfNodePtr, int> &user) {
167       return IsPrimitiveCNode(user.first, prim::kPrimUpdateState);
168     };
169     if (std::all_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
170       external_user_type_[i] = ExternalUserType::kUpdateState;
171       indexes_.push_back(i);
172     } else if (std::any_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
173       external_user_type_[i] = ExternalUserType::kMix;
174       indexes_.push_back(i);
175     }
176   }
177 }
178 
FilterIndexes(const FuncGraphPtr & func_graph)179 void ExtendOutputForUpdateState::FilterIndexes(const FuncGraphPtr &func_graph) {
180   auto output_node = func_graph->output()->cast<CNodePtr>();
181   // do not process the side-effect nodes.
182   indexes_.erase(std::remove_if(indexes_.begin(), indexes_.end(),
183                                 [&output_node](size_t i) { return IsSideEffectNode(output_node->input(i + 1)); }),
184                  indexes_.end());
185 }
186 
FindAllOutputs(const FuncGraphPtr & func_graph,size_t index)187 std::vector<size_t> ExtendOutputForUpdateState::FindAllOutputs(const FuncGraphPtr &func_graph, size_t index) {
188   auto output_node = func_graph->output()->cast<CNodePtr>();
189   auto index_node = output_node->input(index);
190   std::vector<size_t> group;
191 
192   // if the `out_node` is a user (direct or indirect) of the `index_node`, returns true
193   auto DependsOnIndexNode = [&index_node](const AnfNodePtr &out_node) -> bool {
194     bool result = false;
195     auto IncludeFunc = [&result, &index_node](const AnfNodePtr &node) {
196       if (node == index_node) {
197         result = true;
198         return EXCLUDE;
199       }
200       return result ? EXCLUDE : FOLLOW;
201     };
202     static_cast<void>(DeepLinkedGraphSearch(out_node, IncludeFunc));
203     return result;
204   };
205 
206   for (size_t i = 1; i < output_node->size(); i++) {
207     auto out = output_node->input(i);
208     // only process the nodes that depend on index_node.
209     if (!DependsOnIndexNode(out)) continue;
210 
211     // 1. always extend to the side-effect nodes
212     // 2. if the external users are only UpdateState, the related output will be eliminated,
213     //    so only the getitem with realkernel user can be extended to.
214     if (IsSideEffectNode(out) ||
215         (getitems_[i - 1] != nullptr && external_user_type_[i - 1] != ExternalUserType::kUpdateState)) {
216       group.push_back(i - 1);
217     }
218   }
219   return group;
220 }
221 
ProcessIndex(const FuncGraphPtr & func_graph,const FuncGraphPtr & sub_func_graph,size_t index)222 bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph,
223                                               size_t index) {
224   auto group = FindAllOutputs(sub_func_graph, index + 1);
225   AnfNodePtr new_node = nullptr;
226   if (group.size() == 1 && group[0] == index) {
227     return false;
228   }
229   if (group.empty()) {
230     // the output is not side-effect node, but it hasn't realkernel user.
231     // replace the getitem with a value node that is unrelated to the original node.
232     // and this value node will be removed at the later pass.
233     MS_LOG(INFO) << "The " << getitems_[index]->fullname_with_scope() << " only has UpdateState user.";
234     new_node = NewValueNode(kUMonad)->cast<AnfNodePtr>();
235     new_node->set_abstract(kUMonad->ToAbstract());
236   } else {
237     // Create MakeTuple, even though the group size is 1, the following pass will spread the MakeTuple,
238     // so it's unnecessary to set abstract for it.
239     AnfNodePtrList mt_input = {NewValueNode(prim::kPrimMakeTuple)};
240     std::transform(group.begin(), group.end(), std::back_inserter(mt_input),
241                    [this](size_t idx) { return getitems_[idx]; });
242     new_node = func_graph->NewCNode(mt_input)->cast<AnfNodePtr>();
243   }
244   auto mng = func_graph->manager();
245   MS_EXCEPTION_IF_NULL(mng);
246   for (auto user : mng->node_users()[getitems_[index]]) {
247     if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
248       user.first->cast<CNodePtr>()->set_input(IntToSize(user.second), new_node);
249     }
250   }
251   return true;
252 }
253 
Run(const FuncGraphPtr & func_graph)254 bool MergeOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
255   auto todos = GetUpdateStateList(func_graph);
256   bool changed = false;
257   for (auto node : todos) {
258     auto cnode = node->cast<CNodePtr>();
259     MS_EXCEPTION_IF_NULL(cnode);
260     AnfNodePtrList inputs = {cnode->input(0), cnode->input(1)};
261     std::set<AnfNodePtr> node_set;
262     for (size_t i = 2; i < cnode->size(); ++i) {
263       auto input = cnode->input(i);
264       if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
265         // only keep one GetItem for that link to the same node.
266         auto gt_input = input->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
267         if (node_set.insert(gt_input).second) {
268           inputs.push_back(input);
269         }
270       } else if (!HasAbstractUMonad(input)) { /* filter the UMonad that was added in "ExtendOutputForUpdateState" */
271         if (node_set.insert(input).second) {
272           inputs.push_back(input);
273         }
274       }
275     }
276 
277     if (inputs.size() < cnode->size()) {
278       cnode->set_inputs(inputs);
279       changed = true;
280     }
281   }
282   if (changed) {
283     auto mng = func_graph->manager();
284     MS_EXCEPTION_IF_NULL(mng);
285     mng->RemoveRoots();
286     mng->KeepRoots({func_graph});
287   }
288   return changed;
289 }
290 }  // namespace opt
291 }  // namespace mindspore
292