1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/optimizer/parallel/spliter.h"
#include <queue>
#include <unordered_set>
#include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h"
21 namespace mindspore {
22 namespace opt {
GetInstance()23 Spliter *Spliter::GetInstance() {
24 static Spliter spliter;
25 return &spliter;
26 }
27
VisitNodesInputs(const FuncGraphPtr & func_graph)28 void Spliter::VisitNodesInputs(const FuncGraphPtr &func_graph) {
29 // for every node init it's inputs
30 MS_ASSERT(func_graph != nullptr);
31 for (const auto &node : func_graph->GetOrderedCnodes()) {
32 if (!utils::isa<CNodePtr>(node)) {
33 continue;
34 }
35 for (const auto &input : node->inputs()) {
36 if (!utils::isa<CNodePtr>(input)) {
37 continue;
38 }
39 nodes_inputs_[node].insert(input);
40 }
41 }
42 }
43
VisitNodesOutputs(const FuncGraphPtr & func_graph)44 void Spliter::VisitNodesOutputs(const FuncGraphPtr &func_graph) {
45 // for every node init it's outputs
46 for (const auto &node : func_graph->GetOrderedCnodes()) {
47 for (const auto &output_item : nodes_inputs_) {
48 if (output_item.first != node) {
49 for (const auto &output : output_item.second) {
50 if (node == output) {
51 nodes_outputs_[node].insert(output_item.first);
52 }
53 }
54 }
55 }
56 }
57 }
58
RecordGraphInfo(const FuncGraphPtr & func_graph)59 void Spliter::RecordGraphInfo(const FuncGraphPtr &func_graph) {
60 if (func_graph == nullptr) {
61 return;
62 }
63 VisitNodesInputs(func_graph);
64 VisitNodesOutputs(func_graph);
65 for (const auto &node : func_graph->GetOrderedCnodes()) {
66 if (!utils::isa<CNodePtr>(node)) {
67 return;
68 }
69 if (nodes_outputs_[node].size() > kDefaultBatch) {
70 continue;
71 }
72 auto cnode = node->cast<CNodePtr>();
73 auto prim = GetValueNode<PrimitivePtr>(cnode->input(kAnfPrimitiveIndex));
74 MS_ASSERT(prim != nullptr);
75 auto device_type =
76 prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int>(prim->GetAttr(ops::kDeviceType)) : kDeviceTypeNone;
77 // has been searched
78 if (device_type != kDeviceTypeNone) {
79 return;
80 }
81 // check conv && depthwise_conv
82 if (match_visited_[node] || !IsConv2D(node)) {
83 continue;
84 }
85 int match_num = 0;
86 std::queue<AnfNodePtr> conv_nodes;
87 conv_nodes.push(node);
88 while (true) {
89 if (conv_nodes.empty()) {
90 break;
91 }
92 auto curr_node = conv_nodes.front();
93 conv_nodes.pop();
94 if (match_visited_[curr_node]) {
95 continue;
96 }
97 auto curr_cnode = curr_node->cast<CNodePtr>();
98 match_visited_[curr_node] = true;
99 // visit input, default pre_input is 1, and check it's node type whether is conv2d
100 for (const auto &pre_input_node : nodes_inputs_[curr_node]) {
101 if (match_visited_[pre_input_node] || !IsConv2D(pre_input_node)) {
102 break;
103 }
104 conv_nodes.push(pre_input_node);
105 }
106 // visit output
107 if (nodes_outputs_[curr_cnode].size() > kDefaultBatch) {
108 break;
109 }
110 for (const auto &post_output_node : nodes_outputs_[curr_node]) {
111 if (match_visited_[post_output_node] || !IsConv2D(post_output_node)) {
112 break;
113 }
114 conv_nodes.push(post_output_node);
115 }
116 match_num++;
117 }
118 if (match_num != 0) {
119 match_numbers_.insert(match_num);
120 }
121 }
122 }
123
UpdateNodeOutputs(const std::string & input_node_name,const AnfNodePtr & candidate_output)124 void Spliter::UpdateNodeOutputs(const std::string &input_node_name, const AnfNodePtr &candidate_output) {
125 if (candidate_output == nullptr) {
126 return;
127 }
128 if (graph_node_outputs_.find(input_node_name) != graph_node_outputs_.end()) {
129 std::vector<AnfNodePtr>::iterator it;
130 it =
131 find(graph_node_outputs_[input_node_name].begin(), graph_node_outputs_[input_node_name].end(), candidate_output);
132 if (it != graph_node_outputs_[input_node_name].end()) {
133 return;
134 }
135 }
136 graph_node_outputs_[input_node_name].push_back(candidate_output);
137 }
138
UpdateNodeInputShapes(const std::string & node_name,const std::vector<ShapeVector> & input_shapes)139 void Spliter::UpdateNodeInputShapes(const std::string &node_name, const std::vector<ShapeVector> &input_shapes) {
140 graph_node_input_shapes_[node_name] = (input_shapes);
141 }
142
UpdateNodeOutputShapes(const std::string & node_name,const std::vector<ShapeVector> & output_shapes)143 void Spliter::UpdateNodeOutputShapes(const std::string &node_name, const std::vector<ShapeVector> &output_shapes) {
144 graph_node_output_shapes_[node_name] = (output_shapes);
145 }
146
147 } // namespace opt
148 } // namespace mindspore
149