1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
#include "tools/optimizer/parallel/spliter.h"
#include <algorithm>
#include <queue>
#include <unordered_set>
#include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "ops/op_utils.h"
23
24 namespace mindspore {
25 namespace opt {
GetInstance()26 Spliter *Spliter::GetInstance() {
27 static Spliter spliter;
28 return &spliter;
29 }
30
VisitNodesInputs(const FuncGraphPtr & func_graph)31 void Spliter::VisitNodesInputs(const FuncGraphPtr &func_graph) {
32 // for every node init it's inputs
33 MS_ASSERT(func_graph != nullptr);
34 for (const auto &node : func_graph->GetOrderedCnodes()) {
35 if (!utils::isa<CNodePtr>(node)) {
36 continue;
37 }
38 for (const auto &input : node->inputs()) {
39 if (!utils::isa<CNodePtr>(input)) {
40 continue;
41 }
42 nodes_inputs_[node].insert(input);
43 }
44 }
45 }
46
VisitNodesOutputs(const FuncGraphPtr & func_graph)47 void Spliter::VisitNodesOutputs(const FuncGraphPtr &func_graph) {
48 // for every node init it's outputs
49 for (const auto &node : func_graph->GetOrderedCnodes()) {
50 for (const auto &output_item : nodes_inputs_) {
51 if (output_item.first != node) {
52 for (const auto &output : output_item.second) {
53 if (node == output) {
54 nodes_outputs_[node].insert(output_item.first);
55 }
56 }
57 }
58 }
59 }
60 }
61
RecordGraphInfo(const FuncGraphPtr & func_graph)62 void Spliter::RecordGraphInfo(const FuncGraphPtr &func_graph) {
63 MS_ASSERT(func_graph != nullptr);
64 VisitNodesInputs(func_graph);
65 VisitNodesOutputs(func_graph);
66 for (const auto &node : func_graph->GetOrderedCnodes()) {
67 if (!utils::isa<CNodePtr>(node)) {
68 return;
69 }
70 if (nodes_outputs_[node].size() > kDefaultBatch) {
71 continue;
72 }
73 auto cnode = node->cast<CNodePtr>();
74 auto prim = GetValueNode<PrimitivePtr>(cnode->input(kAnfPrimitiveIndex));
75 MS_ASSERT(prim != nullptr);
76 auto device_type =
77 prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int>(prim->GetAttr(ops::kDeviceType)) : kDeviceTypeNone;
78 // has been searched
79 if (device_type != kDeviceTypeNone) {
80 return;
81 }
82 // check conv && depthwise_conv
83 if (match_visited_[node] || !IsConv2D(node)) {
84 continue;
85 }
86 int match_num = 0;
87 std::queue<AnfNodePtr> conv_nodes;
88 conv_nodes.push(node);
89 while (true) {
90 if (conv_nodes.empty()) {
91 break;
92 }
93 auto curr_node = conv_nodes.front();
94 conv_nodes.pop();
95 if (match_visited_[curr_node]) {
96 continue;
97 }
98 auto curr_cnode = curr_node->cast<CNodePtr>();
99 match_visited_[curr_node] = true;
100 // visit input, default pre_input is 1, and check it's node type whether is conv2d
101 for (const auto &pre_input_node : nodes_inputs_[curr_node]) {
102 if (match_visited_[pre_input_node] || !IsConv2D(pre_input_node)) {
103 break;
104 }
105 conv_nodes.push(pre_input_node);
106 }
107 // visit output
108 if (nodes_outputs_[curr_cnode].size() > kDefaultBatch) {
109 break;
110 }
111 for (const auto &post_output_node : nodes_outputs_[curr_node]) {
112 if (match_visited_[post_output_node] || !IsConv2D(post_output_node)) {
113 break;
114 }
115 conv_nodes.push(post_output_node);
116 }
117 match_num++;
118 }
119 if (match_num != 0) {
120 match_numbers_.insert(match_num);
121 }
122 }
123 }
124
UpdateNodeOutputs(const std::string & input_node_name,const AnfNodePtr & candidate_output)125 void Spliter::UpdateNodeOutputs(const std::string &input_node_name, const AnfNodePtr &candidate_output) {
126 if (candidate_output == nullptr) {
127 return;
128 }
129 if (graph_node_outputs_.find(input_node_name) != graph_node_outputs_.end()) {
130 std::vector<AnfNodePtr>::iterator it;
131 it =
132 find(graph_node_outputs_[input_node_name].begin(), graph_node_outputs_[input_node_name].end(), candidate_output);
133 if (it != graph_node_outputs_[input_node_name].end()) {
134 return;
135 }
136 }
137 graph_node_outputs_[input_node_name].push_back(candidate_output);
138 }
139
UpdateNodeInputShapes(const std::string & node_name,const std::vector<ShapeVector> & input_shapes)140 void Spliter::UpdateNodeInputShapes(const std::string &node_name, const std::vector<ShapeVector> &input_shapes) {
141 graph_node_input_shapes_[node_name] = (input_shapes);
142 }
143
UpdateNodeOutputShapes(const std::string & node_name,const std::vector<ShapeVector> & output_shapes)144 void Spliter::UpdateNodeOutputShapes(const std::string &node_name, const std::vector<ShapeVector> &output_shapes) {
145 graph_node_output_shapes_[node_name] = (output_shapes);
146 }
147
148 } // namespace opt
149 } // namespace mindspore
150