• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONST_OUTPUT_ELIMINATE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONST_OUTPUT_ELIMINATE_H_
19 
20 #include <memory>
21 #include <vector>
22 #include "ir/anf.h"
23 #include "frontend/optimizer/optimizer.h"
24 #include "frontend/optimizer/anf_visitor.h"
25 #include "frontend/optimizer/irpass.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "include/common/utils/anfalgo.h"
28 
29 namespace mindspore::opt::irpass {
30 // {a=makeTule(0, 0, 0);return a;} --> {a=makeTuple(0,0,0); b=depend(0, a); return b;}
31 // {a=makeTule(0, 0, 0, grad);return a;} --> {a=makeTuple(0,0,0);b=depend(0, a); c=makeTuple(b, grad); return c;}
32 class ConstOutputEliminater : public AnfVisitor {
33  public:
operator()34   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
35     Reset();
36     auto flag = IsEliminate(node);
37     if (!flag) {
38       return nullptr;
39     }
40 
41     MS_LOG(INFO) << "const output eliminater process";
42 
43     auto fg = GetValueNode<FuncGraphPtr>(node);
44     auto output = fg->output();
45     const size_t min_input_size = 3;
46     const auto &inputs = output->cast<CNodePtr>()->inputs();
47     if (inputs.size() < min_input_size) {
48       MS_LOG(INFO) << "maketuple input size small, size=" << inputs.size();
49       return nullptr;
50     }
51 
52     if (!grad_mode_) {
53       const auto const_data = Tensor0Builder();
54       new_out_abstract_ = const_data->ToAbstract();
55       auto new_value_node = NewValueNode(const_data);
56       new_value_node->set_abstract(new_out_abstract_);
57 
58       auto depend = fg->NewCNode({NewValueNode(prim::kPrimDepend), new_value_node, output});
59       MS_EXCEPTION_IF_NULL(depend);
60       depend->set_abstract(new_out_abstract_);
61       fg->set_output(depend);
62     } else {
63       // Zeros + grad
64       std::vector<AnfNodePtr> zero_inputs(inputs.begin() + 1, inputs.end() - 1);
65       auto grad_input = inputs.back();
66 
67       std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
68       make_tuple_inputs.insert(make_tuple_inputs.end(), zero_inputs.begin(), zero_inputs.end());
69       auto tuple_zero_node_abstract = GetTupleAbstract(zero_inputs);
70       auto tuple_zero_node = fg->NewCNode(make_tuple_inputs);
71       tuple_zero_node->set_abstract(tuple_zero_node_abstract);
72 
73       const auto const_data = Tensor0Builder();
74       auto abstract_tensor = const_data->ToAbstract();
75       auto new_value_node = NewValueNode(const_data);
76       new_value_node->set_abstract(abstract_tensor);
77       auto depend = fg->NewCNode({NewValueNode(prim::kPrimDepend), new_value_node, tuple_zero_node});
78       depend->set_abstract(abstract_tensor);
79 
80       new_out_abstract_ = GetTupleAbstract({new_value_node, grad_input});
81       auto new_out = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), depend, grad_input});
82       new_out->set_abstract(new_out_abstract_);
83       fg->manager()->Replace(output, new_out);
84     }
85     fg->return_node()->set_abstract(new_out_abstract_);
86 
87     (void)DoProcess(fg, true);
88 
89     return nullptr;
90   }
91 
92  private:
93   bool grad_mode_ = false;
94   size_t grad_index_ = 0;
95   AbstractBasePtr new_out_abstract_ = nullptr;
96 
Reset()97   void Reset() {
98     grad_mode_ = false;
99     grad_index_ = 0;
100     new_out_abstract_ = nullptr;
101   }
102 
GetTupleAbstract(const std::vector<AnfNodePtr> & inputs)103   AbstractBasePtr GetTupleAbstract(const std::vector<AnfNodePtr> &inputs) const {
104     AbstractBasePtrList new_sep_abstracts;
105     for (const auto &input : inputs) {
106       new_sep_abstracts.push_back(input->abstract());
107     }
108 
109     return std::make_shared<abstract::AbstractTuple>(new_sep_abstracts);
110   }
111 
IsTupleAllZero(const AnfNodePtr & node)112   bool IsTupleAllZero(const AnfNodePtr &node) {
113     auto tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
114     if (tuple == nullptr) {
115       return false;
116     }
117     size_t element_cnt = 0;
118     for (const auto &element : tuple->elements()) {
119       element_cnt++;
120       if (element->isa<abstract::AbstractTensor>()) {
121         const auto &tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
122         MS_EXCEPTION_IF_NULL(tensor_abstract);
123         auto dim_zero = tensor_abstract->BuildShape()->IsDimZero();
124         auto value_any = tensor_abstract->BuildValue()->isa<ValueAny>();
125         if (!value_any) {
126           return false;
127         }
128 
129         if (element_cnt == tuple->elements().size()) {
130           grad_mode_ = dim_zero ? false : true;
131           grad_index_ = tuple->elements().size() - 1;
132           continue;
133         }
134 
135         if (!dim_zero) {
136           return false;
137         }
138 
139         continue;
140       }
141 
142       if (!element->isa<abstract::AbstractScalar>()) {
143         return false;
144       }
145       const auto &scalar_abstract = element->cast<abstract::AbstractScalarPtr>();
146       MS_EXCEPTION_IF_NULL(scalar_abstract);
147       auto abs_value = scalar_abstract->BuildValue();
148       MS_EXCEPTION_IF_NULL(abs_value);
149       auto abs_int32 = dyn_cast<Int32Imm>(abs_value);
150       if (abs_int32 != nullptr) {
151         if (abs_int32->value() != 0) {
152           return false;
153         }
154         continue;
155       }
156 
157       auto abs_int64 = dyn_cast<Int64Imm>(abs_value);
158       if (abs_int64 == nullptr) {
159         return false;
160       }
161 
162       if (abs_int64->value() != 0) {
163         return false;
164       }
165     }
166 
167     return true;
168   }
169 
IsEliminate(const AnfNodePtr & node)170   bool IsEliminate(const AnfNodePtr &node) {
171     auto fg = GetValueNode<FuncGraphPtr>(node);
172     if (fg == nullptr) {
173       return false;
174     }
175     auto output = fg->output();
176     if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
177       return false;
178     }
179 
180     // Check whether the output is 0
181     if (!IsTupleAllZero(output)) {
182       return false;
183     }
184 
185     // Check output users
186     return DoProcess(fg);
187   }
188 
189   bool DoProcess(const FuncGraphPtr &func, bool is_replace = false) const {
190     MS_EXCEPTION_IF_NULL(func);
191     auto &fg_use_map = func->func_graph_cnodes_index();
192     if (fg_use_map.empty()) {
193       return false;
194     }
195 
196     for (auto &fg_use : fg_use_map) {
197       auto use_node = fg_use.first->first->cast<CNodePtr>();
198       if (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple)) {
199         return false;
200       }
201       auto use_node_graph = use_node->func_graph();
202       auto &fg_use_map_sub = use_node_graph->func_graph_cnodes_index();
203       auto mng_sub = use_node_graph->manager();
204       for (auto &fg_use_sub : fg_use_map_sub) {
205         auto fg_use_node = fg_use_sub.first->first->cast<CNodePtr>();
206         if (fg_use_node == nullptr) {
207           return false;
208         }
209         auto users_sub = mng_sub->node_users()[fg_use_node];
210 
211         auto ret = SubUsersProcess(users_sub, is_replace);
212         if (!ret) {
213           return false;
214         }
215       }
216     }
217 
218     return true;
219   }
220 
SubUsersProcess(const AnfNodeIndexSet & users,bool is_replace)221   bool SubUsersProcess(const AnfNodeIndexSet &users, bool is_replace) const {
222     for (auto &user : users) {
223       if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kDependAttachNodeIndex) {
224         continue;
225       }
226 
227       if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
228         return false;
229       }
230 
231       auto index = common::AnfAlgo::GetTupleGetItemOutIndex(user.first->cast<CNodePtr>());
232       if (index != kIndex1) {
233         continue;
234       }
235 
236       auto mng_sub = user.first->func_graph()->manager();
237       auto users_sub = mng_sub->node_users()[user.first];
238       for (auto &user_sub : users_sub) {
239         if (is_replace) {
240           user_sub.first->set_abstract(new_out_abstract_);
241         }
242 
243         auto ret = ConstNodeRealUserProcess(user_sub.first, user_sub.first->func_graph(), is_replace);
244         if (!ret) {
245           return false;
246         }
247       }
248     }
249 
250     return true;
251   }
252 
ConstNodeRealUserProcess(const AnfNodePtr & node,const FuncGraphPtr & func,bool is_replace)253   bool ConstNodeRealUserProcess(const AnfNodePtr &node, const FuncGraphPtr &func, bool is_replace) const {
254     MS_EXCEPTION_IF_NULL(node);
255     MS_EXCEPTION_IF_NULL(func);
256 
257     auto mng = func->manager();
258     auto users = mng->node_users()[node];
259     if (users.empty()) {
260       return false;
261     }
262 
263     for (auto &user : users) {
264       if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kDependAttachNodeIndex) {
265         continue;
266       }
267 
268       if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
269         return false;
270       }
271 
272       if (!is_replace) {
273         // Check
274         auto ret = RealUserCallerCheck(user.first, user.first->func_graph());
275         if (!ret) {
276           return false;
277         }
278       }
279 
280       if (is_replace) {
281         // Real caller
282         if (!grad_mode_) {
283           mng->Replace(user.first, node);
284         } else {
285           auto index = common::AnfAlgo::GetTupleGetItemOutIndex(user.first->cast<CNodePtr>());
286           auto real_input = common::AnfAlgo::GetTupleGetItemRealInput(user.first->cast<CNodePtr>());
287           size_t new_index = index == grad_index_ ? 1 : 0;
288           auto new_index_value = NewValueNode(MakeValue(SizeToLong(new_index)));
289           auto new_node = func->NewCNode({NewValueNode(prim::kPrimTupleGetItem), real_input, new_index_value});
290           new_node->set_abstract(user.first->abstract());
291           mng->Replace(user.first, new_node);
292         }
293       }
294     }
295 
296     return true;
297   }
298 
Tensor0Builder()299   tensor::TensorPtr Tensor0Builder() const { return std::make_shared<tensor::Tensor>(0.0); }
300 
RealUserCallerCheck(const AnfNodePtr & node,const FuncGraphPtr & func)301   bool RealUserCallerCheck(const AnfNodePtr &node, const FuncGraphPtr &func) const {
302     MS_EXCEPTION_IF_NULL(node);
303     MS_EXCEPTION_IF_NULL(func);
304 
305     auto mng = func->manager();
306     auto &users = mng->node_users()[node];
307 
308     if (users.empty()) {
309       return false;
310     }
311 
312     for (auto &user : users) {
313       if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kDependAttachNodeIndex) {
314         continue;
315       }
316 
317       if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kRealInputIndexInDepend && grad_mode_) {
318         continue;
319       }
320 
321       if (IsPrimitiveCNode(user.first, prim::kPrimSend) && grad_mode_) {
322         continue;
323       }
324 
325       if (!IsPrimitiveCNode(user.first, prim::kPrimMakeTuple)) {
326         return false;
327       }
328 
329       auto tuple = user.first->abstract()->cast<abstract::AbstractTuplePtr>();
330       if (!tuple) {
331         return false;
332       }
333 
334       // Check whether the element of tuple is empty tensor
335       for (const auto &element : tuple->elements()) {
336         if (!element->isa<abstract::AbstractTensor>()) {
337           return false;
338         }
339 
340         const auto &tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
341         MS_EXCEPTION_IF_NULL(tensor_abstract);
342         if (!(tensor_abstract->BuildShape()->IsDimZero() && tensor_abstract->BuildValue()->isa<ValueAny>())) {
343           return false;
344         }
345       }
346     }
347 
348     return true;
349   }
350 };
351 }  // namespace mindspore::opt::irpass
352 
353 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONST_OUTPUT_ELIMINATE_H_
354