• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/pass/common_subexpression_elimination.h"
17 
18 #include <map>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include "include/backend/kernel_info.h"
23 #include "include/backend/optimizer/helper.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/utils.h"
26 #include "ops/array_op_name.h"
27 #include "ops/framework_ops.h"
28 #include "ops/sequence_ops.h"
29 #include "utils/ms_context.h"
30 
31 namespace mindspore {
32 namespace opt {
33 namespace {
34 using KernelWithIndex = std::pair<AnfNodePtr, int64_t>;
35 
CheckIgnoreCase(const AnfNodePtr & node)36 bool CheckIgnoreCase(const AnfNodePtr &node) {
37   MS_EXCEPTION_IF_NULL(node);
38   if (common::AnfAlgo::GetCNodeName(node) != kTransDataOpName) {
39     return false;
40   }
41   auto cnode = node->cast<CNodePtr>();
42   MS_EXCEPTION_IF_NULL(cnode);
43   bool need_ignore = true;
44   auto input_size = cnode->size() - 1;
45   for (size_t k = 0; k < input_size; ++k) {
46     auto input = common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(cnode, k), 0).first;
47     if (input != nullptr && input->isa<CNode>()) {
48       need_ignore = false;
49       break;
50     }
51   }
52   return need_ignore;
53 }
54 
EliminateDuplicatedTupleGetItem(const FuncGraphPtr & graph,const FuncGraphManagerPtr & manager)55 void EliminateDuplicatedTupleGetItem(const FuncGraphPtr &graph, const FuncGraphManagerPtr &manager) {
56   MS_EXCEPTION_IF_NULL(graph);
57   MS_EXCEPTION_IF_NULL(manager);
58 
59   // key: (getitem_input, getitem_index), value: getitem_list
60   std::map<KernelWithIndex, std::vector<AnfNodePtr>> getitem_dup_map;
61   const auto &node_list = TopoSort(graph->get_return());
62   for (auto &node : node_list) {
63     if (!node->isa<CNode>() || !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
64       continue;
65     }
66     auto getitem_cnode = node->cast<CNodePtr>();
67     MS_EXCEPTION_IF_NULL(getitem_cnode);
68     KernelWithIndex input_with_index{getitem_cnode->input(kRealInputNodeIndexInTupleGetItem),
69                                      GetGetitemIndex(getitem_cnode)};
70     if (getitem_dup_map.count(input_with_index) == 0) {
71       getitem_dup_map.emplace(input_with_index, std::vector<AnfNodePtr>{node});
72     } else {
73       getitem_dup_map[input_with_index].push_back(node);
74     }
75   }
76 
77   // remove duplicated
78   for (auto &item : getitem_dup_map) {
79     auto &getitem_list = item.second;
80     if (getitem_list.size() > 1) {
81       auto first_getitem = getitem_list[0];
82       std::for_each(getitem_list.begin() + 1, getitem_list.end(), [first_getitem, manager](const AnfNodePtr &getitem) {
83         (void)manager->Replace(getitem, first_getitem);
84       });
85     }
86   }
87 }
88 }  // namespace
89 
CheckEqualKernelBuildInfo(const AnfNodePtr & main,const AnfNodePtr & node) const90 bool BackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
91   MS_EXCEPTION_IF_NULL(main);
92   MS_EXCEPTION_IF_NULL(node);
93   if (main->isa<CNode>()) {
94     auto main_name = common::AnfAlgo::GetCNodeName(main);
95     if (main_name == prim::kPrimTensorMove->name() || main_name == prim::kPrimMemCpyAsync->name()) {
96       return false;
97     }
98   }
99   auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
100   auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
101   if (main_kernel_info == nullptr && node_kernel_info == nullptr) {
102     return true;
103   }
104   if (main_kernel_info != nullptr && node_kernel_info != nullptr) {
105     return *main_kernel_info == *node_kernel_info;
106   }
107   return false;
108 }
109 
CheckEqualCnodeInputs(const AnfNodePtr & main,const AnfNodePtr & node) const110 bool BackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const {
111   MS_EXCEPTION_IF_NULL(main);
112   MS_EXCEPTION_IF_NULL(node);
113   auto c_main = main->cast<CNodePtr>();
114   MS_EXCEPTION_IF_NULL(c_main);
115   auto c_node = node->cast<CNodePtr>();
116   MS_EXCEPTION_IF_NULL(c_node);
117   const auto &inp1 = c_main->inputs();
118   const auto &inp2 = c_node->inputs();
119   if (inp1.size() != inp2.size()) {
120     return false;
121   }
122   for (size_t j = 0; j < inp1.size(); j++) {
123     auto inp1_j = GetReplicatedNode(inp1[j]);
124     auto inp2_j = GetReplicatedNode(inp2[j]);
125     MS_EXCEPTION_IF_NULL(inp1_j);
126     MS_EXCEPTION_IF_NULL(inp2_j);
127     if ((inp1_j == inp2_j) || (*inp1_j == *inp2_j)) {
128       continue;
129     }
130     // Handle the case of two different Tensor, but with the same value.
131     if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
132       auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
133       auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
134       if (tensor1->ValueEqual(*tensor2)) {
135         continue;
136       }
137     }
138     return false;
139   }
140   return true;
141 }
142 
CheckValueNode(const ValueNodePtr & main,const ValueNodePtr & node) const143 bool BackendCSE::CheckValueNode(const ValueNodePtr &main, const ValueNodePtr &node) const {
144   MS_EXCEPTION_IF_NULL(main);
145   MS_EXCEPTION_IF_NULL(node);
146 
147   auto main_value = main->value();
148   MS_EXCEPTION_IF_NULL(main_value);
149   auto node_value = node->value();
150   MS_EXCEPTION_IF_NULL(node_value);
151   if (main_value->isa<Primitive>() && node_value->isa<Primitive>()) {
152     return false;
153   } else if (main_value->isa<tensor::Tensor>() && node_value->isa<tensor::Tensor>()) {
154     auto main_tensor = main_value->cast<tensor::TensorPtr>();
155     auto node_tensor = node_value->cast<tensor::TensorPtr>();
156     return (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node) &&
157            main_tensor->device_address() == node_tensor->device_address();
158   }
159   return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
160 }
161 
CheckCNode(const CNodePtr & main,const CNodePtr & node)162 bool BackendCSE::CheckCNode(const CNodePtr &main, const CNodePtr &node) {
163   MS_EXCEPTION_IF_NULL(main);
164   MS_EXCEPTION_IF_NULL(node);
165 
166   auto context_ptr = MsContext::GetInstance();
167   MS_EXCEPTION_IF_NULL(context_ptr);
168   if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && CheckIgnoreCase(main)) {
169     return false;
170   }
171   if (HasHiddenSideEffect(main) || HasHiddenSideEffect(node)) {
172     return false;
173   }
174   if (!CheckEqualKernelBuildInfo(main, node)) {
175     return false;
176   }
177   return CheckEqualCnodeInputs(main, node);
178 }
179 
CheckReplace(const AnfNodePtr & main,const AnfNodePtr & node)180 bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) {
181   MS_EXCEPTION_IF_NULL(main);
182   MS_EXCEPTION_IF_NULL(node);
183 
184   // attrs of nop node inserted by backend maybe omitted, so two nodes have same inputs will have different outputs
185   auto main_abs = main->abstract();
186   auto node_abs = node->abstract();
187   if (main_abs != nullptr && node_abs != nullptr && !(*main_abs == *node_abs)) {
188     return false;
189   }
190 
191   if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
192     return CheckValueNode(main->cast<ValueNodePtr>(), node->cast<ValueNodePtr>());
193   } else if (main->isa<CNode>() && node->isa<CNode>()) {
194     return CheckCNode(main->cast<CNodePtr>(), node->cast<CNodePtr>());
195   }
196   return false;
197 }
198 
Cse(const FuncGraphPtr graph,const FuncGraphManagerPtr manager)199 bool BackendCSE::Cse(const FuncGraphPtr graph, const FuncGraphManagerPtr manager) {
200   MS_EXCEPTION_IF_NULL(manager);
201   Init();
202   auto ret = BuildOrderGroupForOneGraph(graph);
203   if (ret) {
204     DoReplace(manager);
205     EliminateDuplicatedTupleGetItem(graph, manager);
206   }
207   return ret;
208 }
209 
Run(const FuncGraphPtr & func_graph)210 bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
211   MS_EXCEPTION_IF_NULL(func_graph);
212   auto backend_cse = std::make_shared<BackendCSE>();
213   MS_EXCEPTION_IF_NULL(backend_cse);
214   return backend_cse->Cse(func_graph, func_graph->manager());
215 }
216 }  // namespace opt
217 }  // namespace mindspore
218