• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tools/optimizer/parallel/spliter.h"
#include <algorithm>
#include <queue>
#include <unordered_set>
#include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h"
21 namespace mindspore {
22 namespace opt {
GetInstance()23 Spliter *Spliter::GetInstance() {
24   static Spliter spliter;
25   return &spliter;
26 }
27 
VisitNodesInputs(const FuncGraphPtr & func_graph)28 void Spliter::VisitNodesInputs(const FuncGraphPtr &func_graph) {
29   // for every node init it's inputs
30   MS_ASSERT(func_graph != nullptr);
31   for (const auto &node : func_graph->GetOrderedCnodes()) {
32     if (!utils::isa<CNodePtr>(node)) {
33       continue;
34     }
35     for (const auto &input : node->inputs()) {
36       if (!utils::isa<CNodePtr>(input)) {
37         continue;
38       }
39       nodes_inputs_[node].insert(input);
40     }
41   }
42 }
43 
VisitNodesOutputs(const FuncGraphPtr & func_graph)44 void Spliter::VisitNodesOutputs(const FuncGraphPtr &func_graph) {
45   // for every node init it's outputs
46   for (const auto &node : func_graph->GetOrderedCnodes()) {
47     for (const auto &output_item : nodes_inputs_) {
48       if (output_item.first != node) {
49         for (const auto &output : output_item.second) {
50           if (node == output) {
51             nodes_outputs_[node].insert(output_item.first);
52           }
53         }
54       }
55     }
56   }
57 }
58 
RecordGraphInfo(const FuncGraphPtr & func_graph)59 void Spliter::RecordGraphInfo(const FuncGraphPtr &func_graph) {
60   if (func_graph == nullptr) {
61     return;
62   }
63   VisitNodesInputs(func_graph);
64   VisitNodesOutputs(func_graph);
65   for (const auto &node : func_graph->GetOrderedCnodes()) {
66     if (!utils::isa<CNodePtr>(node)) {
67       return;
68     }
69     if (nodes_outputs_[node].size() > kDefaultBatch) {
70       continue;
71     }
72     auto cnode = node->cast<CNodePtr>();
73     auto prim = GetValueNode<PrimitivePtr>(cnode->input(kAnfPrimitiveIndex));
74     MS_ASSERT(prim != nullptr);
75     auto device_type =
76       prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int>(prim->GetAttr(ops::kDeviceType)) : kDeviceTypeNone;
77     // has been searched
78     if (device_type != kDeviceTypeNone) {
79       return;
80     }
81     // check conv && depthwise_conv
82     if (match_visited_[node] || !IsConv2D(node)) {
83       continue;
84     }
85     int match_num = 0;
86     std::queue<AnfNodePtr> conv_nodes;
87     conv_nodes.push(node);
88     while (true) {
89       if (conv_nodes.empty()) {
90         break;
91       }
92       auto curr_node = conv_nodes.front();
93       conv_nodes.pop();
94       if (match_visited_[curr_node]) {
95         continue;
96       }
97       auto curr_cnode = curr_node->cast<CNodePtr>();
98       match_visited_[curr_node] = true;
99       // visit input, default pre_input is 1, and check it's node type whether is conv2d
100       for (const auto &pre_input_node : nodes_inputs_[curr_node]) {
101         if (match_visited_[pre_input_node] || !IsConv2D(pre_input_node)) {
102           break;
103         }
104         conv_nodes.push(pre_input_node);
105       }
106       // visit output
107       if (nodes_outputs_[curr_cnode].size() > kDefaultBatch) {
108         break;
109       }
110       for (const auto &post_output_node : nodes_outputs_[curr_node]) {
111         if (match_visited_[post_output_node] || !IsConv2D(post_output_node)) {
112           break;
113         }
114         conv_nodes.push(post_output_node);
115       }
116       match_num++;
117     }
118     if (match_num != 0) {
119       match_numbers_.insert(match_num);
120     }
121   }
122 }
123 
UpdateNodeOutputs(const std::string & input_node_name,const AnfNodePtr & candidate_output)124 void Spliter::UpdateNodeOutputs(const std::string &input_node_name, const AnfNodePtr &candidate_output) {
125   if (candidate_output == nullptr) {
126     return;
127   }
128   if (graph_node_outputs_.find(input_node_name) != graph_node_outputs_.end()) {
129     std::vector<AnfNodePtr>::iterator it;
130     it =
131       find(graph_node_outputs_[input_node_name].begin(), graph_node_outputs_[input_node_name].end(), candidate_output);
132     if (it != graph_node_outputs_[input_node_name].end()) {
133       return;
134     }
135   }
136   graph_node_outputs_[input_node_name].push_back(candidate_output);
137 }
138 
UpdateNodeInputShapes(const std::string & node_name,const std::vector<ShapeVector> & input_shapes)139 void Spliter::UpdateNodeInputShapes(const std::string &node_name, const std::vector<ShapeVector> &input_shapes) {
140   graph_node_input_shapes_[node_name] = (input_shapes);
141 }
142 
UpdateNodeOutputShapes(const std::string & node_name,const std::vector<ShapeVector> & output_shapes)143 void Spliter::UpdateNodeOutputShapes(const std::string &node_name, const std::vector<ShapeVector> &output_shapes) {
144   graph_node_output_shapes_[node_name] = (output_shapes);
145 }
146 
147 }  // namespace opt
148 }  // namespace mindspore
149