1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include "backend/session/anf_runtime_algorithm.h"
25 #include "runtime/device/kernel_info.h"
26
27 namespace mindspore {
28 namespace opt {
29 namespace {
IsCNodePrimitveEqual(const CNodePtr & main,const CNodePtr & node,const std::vector<PrimitivePtr> & black_list)30 bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std::vector<PrimitivePtr> &black_list) {
31 auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
32 auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
33 if (main_primitive != nullptr && node_primitive != nullptr) {
34 // Some ops such as Reshape is not real op, cse these type will not get gain. And for ops fusion, keep these op
35 // alone can prevent some redundant output case (input -> reshape -> output).
36 if (main_primitive->name() != node_primitive->name() ||
37 std::any_of(black_list.begin(), black_list.end(),
38 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
39 return false;
40 }
41
42 auto main_attrs = main_primitive->attrs();
43 auto node_attrs = node_primitive->attrs();
44
45 std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"};
46 for (auto &attr : exclude_attrs) {
47 main_attrs.erase(attr);
48 node_attrs.erase(attr);
49 }
50
51 if (main_attrs.size() != node_attrs.size()) {
52 return false;
53 }
54
55 auto all = std::all_of(main_attrs.begin(), main_attrs.end(),
56 [&node_attrs](const std::pair<std::string, ValuePtr> &item) -> bool {
57 if (item.second == nullptr) {
58 return false;
59 }
60 auto iter = node_attrs.find(item.first);
61 if (iter == node_attrs.end()) {
62 return false;
63 }
64 return *item.second == *iter->second;
65 });
66 return all;
67 }
68
69 return *main->inputs()[0] == *node->inputs()[0];
70 }
71 } // namespace
72
CheckEqualKernelBuildInfo(const AnfNodePtr & main,const AnfNodePtr & node) const73 bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
74 MS_EXCEPTION_IF_NULL(main);
75 MS_EXCEPTION_IF_NULL(node);
76
77 if (!AnfAlgo::IsNodeInGraphKernel(main)) {
78 return BackendCSE::CheckEqualKernelBuildInfo(main, node);
79 }
80
81 auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
82 auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
83 if (main_kernel_info == nullptr && node_kernel_info == nullptr) {
84 return true;
85 }
86
87 if (main_kernel_info != nullptr && node_kernel_info != nullptr) {
88 auto main_build_info = main_kernel_info->GetMutableSelectKernelBuildInfo();
89 auto node_build_info = node_kernel_info->GetMutableSelectKernelBuildInfo();
90 if (main_build_info == nullptr && node_build_info == nullptr) {
91 return true;
92 }
93
94 if (main_build_info == nullptr || node_build_info == nullptr) {
95 return false;
96 }
97
98 if (main_build_info->processor() != node_build_info->processor()) {
99 return false;
100 }
101
102 return main_build_info->IsSimilarityKernelBuildInfo(*node_build_info);
103 }
104 return false;
105 }
106
CheckEqualCnodeInputs(const AnfNodePtr & main,const AnfNodePtr & node) const107 bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const {
108 auto c_main = main->cast<CNodePtr>();
109 MS_EXCEPTION_IF_NULL(c_main);
110 auto c_node = node->cast<CNodePtr>();
111 MS_EXCEPTION_IF_NULL(c_node);
112
113 if (!AnfAlgo::IsNodeInGraphKernel(c_main)) {
114 return BackendCSE::CheckEqualCnodeInputs(main, node);
115 }
116
117 const auto &inp1 = c_main->inputs();
118 const auto &inp2 = c_node->inputs();
119 if (inp1.size() != inp2.size()) {
120 return false;
121 }
122 for (size_t j = 1; j < inp1.size(); j++) {
123 auto inp1_j = inp1[j];
124 auto inp2_j = inp2[j];
125 MS_EXCEPTION_IF_NULL(inp1_j);
126 MS_EXCEPTION_IF_NULL(inp2_j);
127 if (!(*inp1_j == *inp2_j)) {
128 return false;
129 }
130 }
131 return IsCNodePrimitveEqual(c_main, c_node, black_list_);
132 }
133
Run(const FuncGraphPtr & func_graph)134 bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
135 MS_EXCEPTION_IF_NULL(func_graph);
136 auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(black_list_);
137 return graphkernel_backend_cse->Cse(func_graph, func_graph->manager());
138 }
139 } // namespace opt
140 } // namespace mindspore
141