• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include "backend/session/anf_runtime_algorithm.h"
25 #include "runtime/device/kernel_info.h"
26 
27 namespace mindspore {
28 namespace opt {
29 namespace {
IsCNodePrimitveEqual(const CNodePtr & main,const CNodePtr & node,const std::vector<PrimitivePtr> & black_list)30 bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std::vector<PrimitivePtr> &black_list) {
31   auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
32   auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
33   if (main_primitive != nullptr && node_primitive != nullptr) {
34     // Some ops such as Reshape is not real op, cse these type will not get gain. And for ops fusion, keep these op
35     // alone can prevent some redundant output case (input -> reshape -> output).
36     if (main_primitive->name() != node_primitive->name() ||
37         std::any_of(black_list.begin(), black_list.end(),
38                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
39       return false;
40     }
41 
42     auto main_attrs = main_primitive->attrs();
43     auto node_attrs = node_primitive->attrs();
44 
45     std::vector<std::string> exclude_attrs{"IsFeatureMapOutput", "IsFeatureMapInputList", "pri_format"};
46     for (auto &attr : exclude_attrs) {
47       main_attrs.erase(attr);
48       node_attrs.erase(attr);
49     }
50 
51     if (main_attrs.size() != node_attrs.size()) {
52       return false;
53     }
54 
55     auto all = std::all_of(main_attrs.begin(), main_attrs.end(),
56                            [&node_attrs](const std::pair<std::string, ValuePtr> &item) -> bool {
57                              if (item.second == nullptr) {
58                                return false;
59                              }
60                              auto iter = node_attrs.find(item.first);
61                              if (iter == node_attrs.end()) {
62                                return false;
63                              }
64                              return *item.second == *iter->second;
65                            });
66     return all;
67   }
68 
69   return *main->inputs()[0] == *node->inputs()[0];
70 }
71 }  // namespace
72 
CheckEqualKernelBuildInfo(const AnfNodePtr & main,const AnfNodePtr & node) const73 bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
74   MS_EXCEPTION_IF_NULL(main);
75   MS_EXCEPTION_IF_NULL(node);
76 
77   if (!AnfAlgo::IsNodeInGraphKernel(main)) {
78     return BackendCSE::CheckEqualKernelBuildInfo(main, node);
79   }
80 
81   auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
82   auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
83   if (main_kernel_info == nullptr && node_kernel_info == nullptr) {
84     return true;
85   }
86 
87   if (main_kernel_info != nullptr && node_kernel_info != nullptr) {
88     auto main_build_info = main_kernel_info->GetMutableSelectKernelBuildInfo();
89     auto node_build_info = node_kernel_info->GetMutableSelectKernelBuildInfo();
90     if (main_build_info == nullptr && node_build_info == nullptr) {
91       return true;
92     }
93 
94     if (main_build_info == nullptr || node_build_info == nullptr) {
95       return false;
96     }
97 
98     if (main_build_info->processor() != node_build_info->processor()) {
99       return false;
100     }
101 
102     return main_build_info->IsSimilarityKernelBuildInfo(*node_build_info);
103   }
104   return false;
105 }
106 
CheckEqualCnodeInputs(const AnfNodePtr & main,const AnfNodePtr & node) const107 bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const {
108   auto c_main = main->cast<CNodePtr>();
109   MS_EXCEPTION_IF_NULL(c_main);
110   auto c_node = node->cast<CNodePtr>();
111   MS_EXCEPTION_IF_NULL(c_node);
112 
113   if (!AnfAlgo::IsNodeInGraphKernel(c_main)) {
114     return BackendCSE::CheckEqualCnodeInputs(main, node);
115   }
116 
117   const auto &inp1 = c_main->inputs();
118   const auto &inp2 = c_node->inputs();
119   if (inp1.size() != inp2.size()) {
120     return false;
121   }
122   for (size_t j = 1; j < inp1.size(); j++) {
123     auto inp1_j = inp1[j];
124     auto inp2_j = inp2[j];
125     MS_EXCEPTION_IF_NULL(inp1_j);
126     MS_EXCEPTION_IF_NULL(inp2_j);
127     if (!(*inp1_j == *inp2_j)) {
128       return false;
129     }
130   }
131   return IsCNodePrimitveEqual(c_main, c_node, black_list_);
132 }
133 
Run(const FuncGraphPtr & func_graph)134 bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
135   MS_EXCEPTION_IF_NULL(func_graph);
136   auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(black_list_);
137   return graphkernel_backend_cse->Cse(func_graph, func_graph->manager());
138 }
139 }  // namespace opt
140 }  // namespace mindspore
141