1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/pass/common_subexpression_elimination.h"
17
18 #include <map>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include "include/backend/kernel_info.h"
23 #include "include/backend/optimizer/helper.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/utils.h"
26 #include "ops/array_op_name.h"
27 #include "ops/framework_ops.h"
28 #include "ops/sequence_ops.h"
29 #include "utils/ms_context.h"
30
31 namespace mindspore {
32 namespace opt {
33 namespace {
34 using KernelWithIndex = std::pair<AnfNodePtr, int64_t>;
35
CheckIgnoreCase(const AnfNodePtr & node)36 bool CheckIgnoreCase(const AnfNodePtr &node) {
37 MS_EXCEPTION_IF_NULL(node);
38 if (common::AnfAlgo::GetCNodeName(node) != kTransDataOpName) {
39 return false;
40 }
41 auto cnode = node->cast<CNodePtr>();
42 MS_EXCEPTION_IF_NULL(cnode);
43 bool need_ignore = true;
44 auto input_size = cnode->size() - 1;
45 for (size_t k = 0; k < input_size; ++k) {
46 auto input = common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(cnode, k), 0).first;
47 if (input != nullptr && input->isa<CNode>()) {
48 need_ignore = false;
49 break;
50 }
51 }
52 return need_ignore;
53 }
54
EliminateDuplicatedTupleGetItem(const FuncGraphPtr & graph,const FuncGraphManagerPtr & manager)55 void EliminateDuplicatedTupleGetItem(const FuncGraphPtr &graph, const FuncGraphManagerPtr &manager) {
56 MS_EXCEPTION_IF_NULL(graph);
57 MS_EXCEPTION_IF_NULL(manager);
58
59 // key: (getitem_input, getitem_index), value: getitem_list
60 std::map<KernelWithIndex, std::vector<AnfNodePtr>> getitem_dup_map;
61 const auto &node_list = TopoSort(graph->get_return());
62 for (auto &node : node_list) {
63 if (!node->isa<CNode>() || !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
64 continue;
65 }
66 auto getitem_cnode = node->cast<CNodePtr>();
67 MS_EXCEPTION_IF_NULL(getitem_cnode);
68 KernelWithIndex input_with_index{getitem_cnode->input(kRealInputNodeIndexInTupleGetItem),
69 GetGetitemIndex(getitem_cnode)};
70 if (getitem_dup_map.count(input_with_index) == 0) {
71 getitem_dup_map.emplace(input_with_index, std::vector<AnfNodePtr>{node});
72 } else {
73 getitem_dup_map[input_with_index].push_back(node);
74 }
75 }
76
77 // remove duplicated
78 for (auto &item : getitem_dup_map) {
79 auto &getitem_list = item.second;
80 if (getitem_list.size() > 1) {
81 auto first_getitem = getitem_list[0];
82 std::for_each(getitem_list.begin() + 1, getitem_list.end(), [first_getitem, manager](const AnfNodePtr &getitem) {
83 (void)manager->Replace(getitem, first_getitem);
84 });
85 }
86 }
87 }
88 } // namespace
89
CheckEqualKernelBuildInfo(const AnfNodePtr & main,const AnfNodePtr & node) const90 bool BackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
91 MS_EXCEPTION_IF_NULL(main);
92 MS_EXCEPTION_IF_NULL(node);
93 if (main->isa<CNode>()) {
94 auto main_name = common::AnfAlgo::GetCNodeName(main);
95 if (main_name == prim::kPrimTensorMove->name() || main_name == prim::kPrimMemCpyAsync->name()) {
96 return false;
97 }
98 }
99 auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
100 auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
101 if (main_kernel_info == nullptr && node_kernel_info == nullptr) {
102 return true;
103 }
104 if (main_kernel_info != nullptr && node_kernel_info != nullptr) {
105 return *main_kernel_info == *node_kernel_info;
106 }
107 return false;
108 }
109
CheckEqualCnodeInputs(const AnfNodePtr & main,const AnfNodePtr & node) const110 bool BackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const {
111 MS_EXCEPTION_IF_NULL(main);
112 MS_EXCEPTION_IF_NULL(node);
113 auto c_main = main->cast<CNodePtr>();
114 MS_EXCEPTION_IF_NULL(c_main);
115 auto c_node = node->cast<CNodePtr>();
116 MS_EXCEPTION_IF_NULL(c_node);
117 const auto &inp1 = c_main->inputs();
118 const auto &inp2 = c_node->inputs();
119 if (inp1.size() != inp2.size()) {
120 return false;
121 }
122 for (size_t j = 0; j < inp1.size(); j++) {
123 auto inp1_j = GetReplicatedNode(inp1[j]);
124 auto inp2_j = GetReplicatedNode(inp2[j]);
125 MS_EXCEPTION_IF_NULL(inp1_j);
126 MS_EXCEPTION_IF_NULL(inp2_j);
127 if ((inp1_j == inp2_j) || (*inp1_j == *inp2_j)) {
128 continue;
129 }
130 // Handle the case of two different Tensor, but with the same value.
131 if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
132 auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
133 auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
134 if (tensor1->ValueEqual(*tensor2)) {
135 continue;
136 }
137 }
138 return false;
139 }
140 return true;
141 }
142
CheckValueNode(const ValueNodePtr & main,const ValueNodePtr & node) const143 bool BackendCSE::CheckValueNode(const ValueNodePtr &main, const ValueNodePtr &node) const {
144 MS_EXCEPTION_IF_NULL(main);
145 MS_EXCEPTION_IF_NULL(node);
146
147 auto main_value = main->value();
148 MS_EXCEPTION_IF_NULL(main_value);
149 auto node_value = node->value();
150 MS_EXCEPTION_IF_NULL(node_value);
151 if (main_value->isa<Primitive>() && node_value->isa<Primitive>()) {
152 return false;
153 } else if (main_value->isa<tensor::Tensor>() && node_value->isa<tensor::Tensor>()) {
154 auto main_tensor = main_value->cast<tensor::TensorPtr>();
155 auto node_tensor = node_value->cast<tensor::TensorPtr>();
156 return (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node) &&
157 main_tensor->device_address() == node_tensor->device_address();
158 }
159 return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
160 }
161
CheckCNode(const CNodePtr & main,const CNodePtr & node)162 bool BackendCSE::CheckCNode(const CNodePtr &main, const CNodePtr &node) {
163 MS_EXCEPTION_IF_NULL(main);
164 MS_EXCEPTION_IF_NULL(node);
165
166 auto context_ptr = MsContext::GetInstance();
167 MS_EXCEPTION_IF_NULL(context_ptr);
168 if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && CheckIgnoreCase(main)) {
169 return false;
170 }
171 if (HasHiddenSideEffect(main) || HasHiddenSideEffect(node)) {
172 return false;
173 }
174 if (!CheckEqualKernelBuildInfo(main, node)) {
175 return false;
176 }
177 return CheckEqualCnodeInputs(main, node);
178 }
179
CheckReplace(const AnfNodePtr & main,const AnfNodePtr & node)180 bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) {
181 MS_EXCEPTION_IF_NULL(main);
182 MS_EXCEPTION_IF_NULL(node);
183
184 // attrs of nop node inserted by backend maybe omitted, so two nodes have same inputs will have different outputs
185 auto main_abs = main->abstract();
186 auto node_abs = node->abstract();
187 if (main_abs != nullptr && node_abs != nullptr && !(*main_abs == *node_abs)) {
188 return false;
189 }
190
191 if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
192 return CheckValueNode(main->cast<ValueNodePtr>(), node->cast<ValueNodePtr>());
193 } else if (main->isa<CNode>() && node->isa<CNode>()) {
194 return CheckCNode(main->cast<CNodePtr>(), node->cast<CNodePtr>());
195 }
196 return false;
197 }
198
Cse(const FuncGraphPtr graph,const FuncGraphManagerPtr manager)199 bool BackendCSE::Cse(const FuncGraphPtr graph, const FuncGraphManagerPtr manager) {
200 MS_EXCEPTION_IF_NULL(manager);
201 Init();
202 auto ret = BuildOrderGroupForOneGraph(graph);
203 if (ret) {
204 DoReplace(manager);
205 EliminateDuplicatedTupleGetItem(graph, manager);
206 }
207 return ret;
208 }
209
Run(const FuncGraphPtr & func_graph)210 bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
211 MS_EXCEPTION_IF_NULL(func_graph);
212 auto backend_cse = std::make_shared<BackendCSE>();
213 MS_EXCEPTION_IF_NULL(backend_cse);
214 return backend_cse->Cse(func_graph, func_graph->manager());
215 }
216 } // namespace opt
217 } // namespace mindspore
218