• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
#include "tools/optimizer/parallel/spliter.h"
#include <algorithm>
#include <queue>
#include <unordered_set>
#include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "ops/op_utils.h"
23 
24 namespace mindspore {
25 namespace opt {
GetInstance()26 Spliter *Spliter::GetInstance() {
27   static Spliter spliter;
28   return &spliter;
29 }
30 
VisitNodesInputs(const FuncGraphPtr & func_graph)31 void Spliter::VisitNodesInputs(const FuncGraphPtr &func_graph) {
32   // for every node init it's inputs
33   MS_ASSERT(func_graph != nullptr);
34   for (const auto &node : func_graph->GetOrderedCnodes()) {
35     if (!utils::isa<CNodePtr>(node)) {
36       continue;
37     }
38     for (const auto &input : node->inputs()) {
39       if (!utils::isa<CNodePtr>(input)) {
40         continue;
41       }
42       nodes_inputs_[node].insert(input);
43     }
44   }
45 }
46 
VisitNodesOutputs(const FuncGraphPtr & func_graph)47 void Spliter::VisitNodesOutputs(const FuncGraphPtr &func_graph) {
48   // for every node init it's outputs
49   for (const auto &node : func_graph->GetOrderedCnodes()) {
50     for (const auto &output_item : nodes_inputs_) {
51       if (output_item.first != node) {
52         for (const auto &output : output_item.second) {
53           if (node == output) {
54             nodes_outputs_[node].insert(output_item.first);
55           }
56         }
57       }
58     }
59   }
60 }
61 
RecordGraphInfo(const FuncGraphPtr & func_graph)62 void Spliter::RecordGraphInfo(const FuncGraphPtr &func_graph) {
63   MS_ASSERT(func_graph != nullptr);
64   VisitNodesInputs(func_graph);
65   VisitNodesOutputs(func_graph);
66   for (const auto &node : func_graph->GetOrderedCnodes()) {
67     if (!utils::isa<CNodePtr>(node)) {
68       return;
69     }
70     if (nodes_outputs_[node].size() > kDefaultBatch) {
71       continue;
72     }
73     auto cnode = node->cast<CNodePtr>();
74     auto prim = GetValueNode<PrimitivePtr>(cnode->input(kAnfPrimitiveIndex));
75     MS_ASSERT(prim != nullptr);
76     auto device_type =
77       prim->GetAttr(ops::kDeviceType) != nullptr ? GetValue<int>(prim->GetAttr(ops::kDeviceType)) : kDeviceTypeNone;
78     // has been searched
79     if (device_type != kDeviceTypeNone) {
80       return;
81     }
82     // check conv && depthwise_conv
83     if (match_visited_[node] || !IsConv2D(node)) {
84       continue;
85     }
86     int match_num = 0;
87     std::queue<AnfNodePtr> conv_nodes;
88     conv_nodes.push(node);
89     while (true) {
90       if (conv_nodes.empty()) {
91         break;
92       }
93       auto curr_node = conv_nodes.front();
94       conv_nodes.pop();
95       if (match_visited_[curr_node]) {
96         continue;
97       }
98       auto curr_cnode = curr_node->cast<CNodePtr>();
99       match_visited_[curr_node] = true;
100       // visit input, default pre_input is 1, and check it's node type whether is conv2d
101       for (const auto &pre_input_node : nodes_inputs_[curr_node]) {
102         if (match_visited_[pre_input_node] || !IsConv2D(pre_input_node)) {
103           break;
104         }
105         conv_nodes.push(pre_input_node);
106       }
107       // visit output
108       if (nodes_outputs_[curr_cnode].size() > kDefaultBatch) {
109         break;
110       }
111       for (const auto &post_output_node : nodes_outputs_[curr_node]) {
112         if (match_visited_[post_output_node] || !IsConv2D(post_output_node)) {
113           break;
114         }
115         conv_nodes.push(post_output_node);
116       }
117       match_num++;
118     }
119     if (match_num != 0) {
120       match_numbers_.insert(match_num);
121     }
122   }
123 }
124 
UpdateNodeOutputs(const std::string & input_node_name,const AnfNodePtr & candidate_output)125 void Spliter::UpdateNodeOutputs(const std::string &input_node_name, const AnfNodePtr &candidate_output) {
126   if (candidate_output == nullptr) {
127     return;
128   }
129   if (graph_node_outputs_.find(input_node_name) != graph_node_outputs_.end()) {
130     std::vector<AnfNodePtr>::iterator it;
131     it =
132       find(graph_node_outputs_[input_node_name].begin(), graph_node_outputs_[input_node_name].end(), candidate_output);
133     if (it != graph_node_outputs_[input_node_name].end()) {
134       return;
135     }
136   }
137   graph_node_outputs_[input_node_name].push_back(candidate_output);
138 }
139 
UpdateNodeInputShapes(const std::string & node_name,const std::vector<ShapeVector> & input_shapes)140 void Spliter::UpdateNodeInputShapes(const std::string &node_name, const std::vector<ShapeVector> &input_shapes) {
141   graph_node_input_shapes_[node_name] = (input_shapes);
142 }
143 
UpdateNodeOutputShapes(const std::string & node_name,const std::vector<ShapeVector> & output_shapes)144 void Spliter::UpdateNodeOutputShapes(const std::string &node_name, const std::vector<ShapeVector> &output_shapes) {
145   graph_node_output_shapes_[node_name] = (output_shapes);
146 }
147 
148 }  // namespace opt
149 }  // namespace mindspore
150