• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/pass/optimize_dependence.h"
18 #include <memory>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include "backend/optimizer/common/helper.h"
23 #include "base/core_ops.h"
24 #include "utils/utils.h"
25 #include "backend/session/kernel_graph.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 
28 namespace mindspore {
29 namespace opt {
// Index constants used throughout this pass.
// kSingleInputIndex serves two purposes:
//   * as a CNode::input() slot, it is the operand of a TransData/Cast (slot 0 holds the primitive);
//   * as the count passed to CheckCNodeInputSize.
constexpr int kSingleInputIndex = 1;
// The next two are AnfAlgo::GetInputNode() positions of an "isolated" Depend/Load.
// The matching CNode::inputs() slot is this value + 1, because slot 0 holds the primitive.
// Position of the real (data) operand.
constexpr int kIsolatedDependRealInputIndex = 0;
// Position of the operand that is expected to carry a monad abstract.
constexpr int kIsolatedDependVirtualInputIndex = 1;
33 namespace {
// Builds a node that stands in for `cnode` but reads `new_depend_inputs`.
// The result is not spliced into the graph; the caller does that.
//   * KernelGraph: `cnode` is cloned via KernelGraph::NewCNode(cnode), and the clone's
//     input list is then overwritten.
//   * Any other FuncGraph: a fresh CNode is created from the inputs, and `cnode`'s
//     abstract and scope are copied onto it.
CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                             const std::vector<AnfNodePtr> &new_depend_inputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  const auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  if (kernel_graph != nullptr) {
    auto cloned = kernel_graph->NewCNode(cnode);
    MS_EXCEPTION_IF_NULL(cloned);
    cloned->set_inputs(new_depend_inputs);
    return cloned;
  }
  auto fresh = func_graph->NewCNode(new_depend_inputs);
  MS_EXCEPTION_IF_NULL(fresh);
  fresh->set_abstract(cnode->abstract());
  fresh->set_scope(cnode->scope());
  return fresh;
}
51 
// Detects an "isolated" Depend/Load.
// Returns the node's real operand (as a CNode) only when all of these hold:
//   * `cnode` is a Depend or a Load;
//   * its operand at kIsolatedDependVirtualInputIndex passes HasAbstractMonad;
//   * its operand at kIsolatedDependRealInputIndex is a CNode.
// In every other case the result is nullptr.
CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const auto node_name = AnfAlgo::GetCNodeName(cnode);
  const bool is_depend_or_load = (node_name == prim::kPrimDepend->name()) || (node_name == prim::kPrimLoad->name());
  if (!is_depend_or_load) {
    return nullptr;
  }
  const auto virtual_input = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
  if (!HasAbstractMonad(virtual_input)) {
    return nullptr;
  }
  const auto real_input = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
  MS_EXCEPTION_IF_NULL(real_input);
  if (real_input->isa<CNode>()) {
    return real_input->cast<CNodePtr>();
  }
  return nullptr;
}
70 
EliminateIsolatedVirtualNodeInput(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const CNodePtr & eliminate_node)71 AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
72                                              const CNodePtr &eliminate_node) {
73   MS_EXCEPTION_IF_NULL(func_graph);
74   MS_EXCEPTION_IF_NULL(cnode);
75   MS_EXCEPTION_IF_NULL(eliminate_node);
76   auto replace_node = eliminate_node->input(kSingleInputIndex);
77   std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
78   new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
79   auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
80   (void)func_graph->manager()->Replace(cnode, new_depend);
81   return new_depend;
82 }
83 
// Computes what should stand in for `node` once a removable TransData/Cast is dropped.
// Returns nullptr when nothing can be eliminated. Two cases are handled:
//   * `node` is itself a TransData/Cast: its single input (slot kSingleInputIndex) is returned.
//   * `node` is an isolated Depend/Load wrapping a TransData/Cast: a rebuilt Depend/Load
//     that reads the TransData/Cast input directly is created, swapped into the graph via
//     the manager, and returned.
// In both cases IsNotRealUsedByOthers() must approve the TransData/Cast.
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Look through an isolated Depend/Load (operand with a monad abstract) to the node it wraps.
  const CNodePtr isolated_cnode = CheckIsolatedVirtualNode(cnode);
  const CNodePtr replace_cnode = (isolated_cnode != nullptr) ? isolated_cnode : cnode;
  const std::string op_name = AnfAlgo::GetCNodeName(replace_cnode);
  // Currently only TransData and Cast nodes are eliminated.
  if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
    return nullptr;
  }
  // Presumably this rejects nodes whose output still has a real consumer elsewhere
  // -- confirm against the helper's definition.
  if (!IsNotRealUsedByOthers(func_graph, replace_cnode)) {
    return nullptr;
  }
  CheckCNodeInputSize(replace_cnode, kSingleInputIndex);
  if (isolated_cnode != nullptr) {
    return EliminateIsolatedVirtualNodeInput(func_graph, cnode, replace_cnode);
  }
  // Not isolated: replace_cnode is cnode itself, so hand back its single operand.
  return replace_cnode->input(kSingleInputIndex);
}
113 
ReplaceMakeTuple(const FuncGraphPtr & func_graph,const CNodePtr & cnode)114 AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
115   MS_EXCEPTION_IF_NULL(func_graph);
116   MS_EXCEPTION_IF_NULL(cnode);
117   if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
118     return nullptr;
119   }
120   std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
121   bool need_update = false;
122   size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
123   for (size_t index = 0; index < input_num; ++index) {
124     auto input = AnfAlgo::GetInputNode(cnode, index);
125     AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
126     // If replace input is not null, it will be the input of the TransData or Cast.
127     if (replace_input == nullptr) {
128       new_make_tuple_inputs.push_back(input);
129       continue;
130     }
131     new_make_tuple_inputs.push_back(replace_input);
132     need_update = true;
133   }
134   if (need_update) {
135     auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
136     CNodePtr new_make_tuple = nullptr;
137     if (kernel_graph == nullptr) {
138       new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
139     } else {
140       new_make_tuple = kernel_graph->NewCNode(cnode);
141     }
142     MS_EXCEPTION_IF_NULL(new_make_tuple);
143     new_make_tuple->set_inputs(new_make_tuple_inputs);
144     auto manager = func_graph->manager();
145     MS_EXCEPTION_IF_NULL(manager);
146     manager->Replace(cnode, new_make_tuple);
147     return new_make_tuple;
148   }
149   return nullptr;
150 }
151 }  // namespace
152 
DefinePattern() const153 const BaseRef OptimizeDependence::DefinePattern() const {
154   VarPtr X = std::make_shared<Var>();
155   VarPtr Xs = std::make_shared<SeqVar>();
156   return VectorRef({X, Xs});
157 }
158 
SearchTransDataAndCast(const CNodePtr & cnode)159 std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) {
160   // Search Depend and UpdateState only.
161   if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) {
162     return {};
163   }
164   // Find inputs which is Cast or TransData.
165   std::vector<size_t> result;
166   for (size_t i = 1; i < cnode->size(); ++i) {
167     auto &input = cnode->input(i);
168     if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) ||
169         AnfAlgo::CheckPrimitiveType(input, prim::kPrimTransData) ||
170         AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
171       (void)result.emplace_back(i);
172     }
173   }
174   return result;
175 }
176 
// Entry point for each matched node. For a Depend/UpdateState, every Cast/TransData/MakeTuple
// input that GetConvertNode can simplify is swapped for its replacement. The node is then
// rebuilt with the new inputs and substituted via the graph manager.
// Always returns nullptr: the substitution is done here, not by the caller.
const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = dyn_cast<CNode>(node);
  if (cnode == nullptr) {
    return nullptr;
  }
  // Search inputs to be replaced.
  auto candidate_inputs = SearchTransDataAndCast(cnode);
  if (candidate_inputs.empty()) {
    return nullptr;
  }
  // Collect (input slot, replacement) pairs first.
  std::vector<std::pair<size_t, AnfNodePtr>> replacements;
  for (auto index : candidate_inputs) {
    if (index >= cnode->size()) {
      MS_LOG(EXCEPTION) << "Index is out of the size of " << cnode->DebugString() << " inputs.";
    }
    auto replace_node = GetConvertNode(func_graph, cnode, index);
    if (replace_node != nullptr) {
      (void)replacements.emplace_back(index, replace_node);
    }
  }
  if (replacements.empty()) {
    return nullptr;
  }
  // Snapshot the inputs only now. GetConvertNode may already have rewired some of cnode's
  // inputs through manager->Replace (e.g. a rebuilt MakeTuple that appears at more than one
  // slot). A copy taken before the loop would re-introduce the already-replaced node.
  std::vector<AnfNodePtr> new_inputs = cnode->inputs();
  for (const auto &replacement : replacements) {
    new_inputs[replacement.first] = replacement.second;
  }
  // Create a new Depend/UpdateState node to replace the old one.
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs);
  (void)manager->Replace(cnode, new_depend);
  return nullptr;
}
211 
// Decides what should replace input slot `index` of the Depend/UpdateState `node`:
//   * a rebuilt MakeTuple, when that input is a MakeTuple with removable TransData/Cast elements;
//   * otherwise, whatever GetReplaceNode yields, which may be nullptr.
// A non-CNode input yields nullptr.
const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                    const size_t index) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  const auto depend_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(depend_cnode);
  // Hold our own reference: ReplaceMakeTuple may rewire this slot through the graph manager.
  const AnfNodePtr input = depend_cnode->input(index);
  MS_EXCEPTION_IF_NULL(input);
  const auto input_cnode = dyn_cast<CNode>(input);
  if (input_cnode == nullptr) {
    return nullptr;
  }
  // Deal with a MakeTuple holding TransData/Cast elements first.
  const AnfNodePtr rebuilt_tuple = ReplaceMakeTuple(graph, input_cnode);
  return (rebuilt_tuple != nullptr) ? rebuilt_tuple : GetReplaceNode(graph, input_cnode);
}
233 }  // namespace opt
234 }  // namespace mindspore
235