/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "frontend/parallel/status.h"

namespace mindspore {
namespace parallel {
// Get the target node's weight for sorting.
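// The returned weight is the cost estimate reported by the operator's cost model and is used by
// SortByWeight() to order nodes before partitioning.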
double GetWeights(const Graph::NodeType &node) {
  const OperatorRec &op = node.apply;

  if (op.op_type == OperatorType::kRecMatMul) {
    // For MatMul
    auto cost_ptr = std::make_shared<CostMatMul>();

    return cost_ptr->GetMinCostIn(op);
  } else if (op.op_type == OperatorType::kRecConvolution) {
    // For Convolution
    auto cost_ptr = std::make_shared<CostConvolution>();

    return cost_ptr->GetMinCostIn(node);
  } else if (op.op_type == OperatorType::kRecPooling) {
    // For Pooling
    auto cost_ptr = std::make_shared<CostPooling>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecElmWiseOp) {
    // For TensorAdd
    auto cost_ptr = std::make_shared<CostTensorAdd>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecReLU) {
    // For Activation
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecReshape) {
    // For Reshape
    auto cost_ptr = std::make_shared<CostReshape>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecBiasAdd) {
    // For BiasAdd
    auto cost_ptr = std::make_shared<CostBiasAdd>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp ||
             op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub ||
             op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv ||
             op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) {
    // For element-wise op
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecUnsortedSegmentOp ||
             op.op_type == OperatorType::kRecSoftmax ||
             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
             op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
    // For BatchParallel op
    auto cost_ptr = std::make_shared<CostBatchParallel>();

    return cost_ptr->GetMaxCostIn();
  } else if (op.op_type == OperatorType::kRecUnkownType) {
    // For unknown type
    return 0.0;
  } else {
    MS_LOG(EXCEPTION) << "Failure: GetWeights failed.";
  }
}

// Sort the graph's application nodes by weight and return their indices in descending order.
std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  std::vector<std::pair<double, size_t>> weight_to_node_index;
  std::vector<size_t> node_index_by_weights;

  // Get node's weight.
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    if (graph->nodes[i].info == kApplication) {
      const Graph::NodeType &node_ptr = graph->nodes[i];
      double weight = GetWeights(node_ptr);
      size_t index = i;
      weight_to_node_index.push_back(std::make_pair(weight, index));
    }
  }
  // Order the nodes (ops) of the graph by ascending weight.
  std::sort(weight_to_node_index.begin(), weight_to_node_index.end());

  // Store the node indices in descending order of weight.
  uint64_t size = weight_to_node_index.size();
  for (uint64_t i = 1; i <= size; i++) {
    node_index_by_weights.push_back(weight_to_node_index[size - i].second);
  }

  return node_index_by_weights;
}

// Get optimal strategy to partition the target node.
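// node_name_to_strategy holds the strategies already chosen for the nodes partitioned earlier in the
// current pass; PartitionNode forwards it, together with the graph, to the operator cost models.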
StrategyRec PartitionNode(const Graph::NodeType &node,
                          const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                          const std::shared_ptr<Graph> &graph) {
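  // CHW (channel/height/width) partitioning of Convolution is disabled here; the flag is forwarded
  // to CostConvolution::GetOptimalStr() below.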
  bool enable_conv_chw_partition = false;
  MS_EXCEPTION_IF_NULL(graph);

  if (node.apply.op_type == OperatorType::kRecMatMul) {
    // For MatMul
    auto cost_ptr = std::make_shared<CostMatMul>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecConvolution) {
    // For Convolution
    auto cost_ptr = std::make_shared<CostConvolution>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, enable_conv_chw_partition);
  } else if (node.apply.op_type == OperatorType::kRecPooling) {
    // For Pooling
    auto cost_ptr = std::make_shared<CostPooling>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecElmWiseOp) {
    // For TensorAdd
    auto cost_ptr = std::make_shared<CostTensorAdd>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecReLU) {
    // For Activation
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecReshape) {
    // For Reshape
    auto cost_ptr = std::make_shared<CostReshape>();

    return cost_ptr->GetOptimalStr(node);
  } else if (node.apply.op_type == OperatorType::kRecBiasAdd) {
    // For BiasAdd
    auto cost_ptr = std::make_shared<CostBiasAdd>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp ||
             node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub ||
             node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv ||
             node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) {
    // For element-wise op
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
             node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == OperatorType::kRecSoftmax ||
             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
             node.apply.op_type == OperatorType::kRecUnsortedSegmentOp) {
    // For BatchParallel type
    auto cost_ptr = std::make_shared<CostBatchParallel>();
    return cost_ptr->GetOptimalStr(node);
  } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
    // For SoftmaxCrossEntropyWithLogits type
    auto cost_ptr = std::make_shared<CostSoftmaxCrossEntropyWithLogits>();
    return cost_ptr->GetOptimalStr(node);
  } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
    // For unknown type
    StrategyRec default_strategy;
    return default_strategy;
  } else {
    MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
  }
}

// Partition the graph across all devices.
Status PartitionForAllDevices(const size_t num_device, const double device_memory,
                              const std::shared_ptr<Graph> &graph) {
  if (num_device < 1) {
    MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
  }

  if (num_device > 1024) {
    MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be larger than 1024.";
  }

  MS_EXCEPTION_IF_NULL(graph);

  // Compute the number of partition iterations.
  int64_t iter_times = static_cast<int64_t>(log2(num_device));
  if (iter_times > 10) {
    MS_LOG(EXCEPTION) << "ERROR: Number of iter_times can't be larger than 10.";
  }
  // N-cuts loop
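  // Each pass partitions every node once more, so log2(num_device) passes split the graph across
  // num_device devices.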
  for (int64_t loop = 0; loop < iter_times; loop++) {
    // Sort by weights
    std::vector<size_t> reorder_node_list = SortByWeight(graph);

    // Get the total node number.
    size_t iter_nodes = reorder_node_list.size();

    // Temporary vector mapping node names to their strategies in this pass.
    std::vector<std::pair<std::string, StrategyRec>> node_name_to_strategy;

    // Loop over all the nodes.
    for (size_t i_node = 0; i_node < iter_nodes; i_node++) {
      // Get the current node's index.
      size_t index = reorder_node_list[i_node];

      Graph::NodeType &node_ptr = graph->nodes[index];

      // Search for the optimal strategy to cut this operator and store it in the graph.
      graph->nodes[index].apply.str = PartitionNode(node_ptr, node_name_to_strategy, graph);

      // Apply the OP strategy to the tensor strategy.
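      // node_ptr references graph->nodes[index], so the strategy stored above is visible through it here.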
      graph->nodes[index] = ApplyStrToTensor(node_ptr);

      // Note down the node name and its strategy in this loop.
      auto node_name_to_str =
        std::pair<std::string, StrategyRec>(graph->nodes[index].name, graph->nodes[index].apply.str);
      node_name_to_strategy.push_back(node_name_to_str);
    }
  }

  if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) {
    return FAILED;
  } else {
    return SUCCESS;
  }
}

// Apply the OP strategy to the tensor strategy.
Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
  // Set the node's output tensor_parm.
  Node.tensor_parm.tensor_str.str_n = Node.apply.str.outputTensor.str_n;
  Node.tensor_parm.tensor_str.str_c = Node.apply.str.outputTensor.str_c;
  Node.tensor_parm.tensor_str.str_h = Node.apply.str.outputTensor.str_h;
  Node.tensor_parm.tensor_str.str_w = Node.apply.str.outputTensor.str_w;

  // Set the input tensors' tensor_parm.
  for (int64_t i = 0; i < 2; i++) {
    Node.apply.arguments[i].tensor_str.str_n = Node.apply.str.inputTensor[i].str_n;
    Node.apply.arguments[i].tensor_str.str_c = Node.apply.str.inputTensor[i].str_c;
    Node.apply.arguments[i].tensor_str.str_h = Node.apply.str.inputTensor[i].str_h;
    Node.apply.arguments[i].tensor_str.str_w = Node.apply.str.inputTensor[i].str_w;
  }
  return Node;
}

Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (num_device == 0) {
    MS_LOG(EXCEPTION) << "Failure: device number is 0.";
  }

  uint64_t iter_nodes = graph->nodes.size();
  double used_memory = 0.0;

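  // Accumulate, for every application node, the sharded size of its two input tensors:
  // strategy fraction * full shape in each dimension * element size in bytes.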
  for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
    if (graph->nodes[i_node].info == kApplication) {
      Graph::NodeType &Node = graph->nodes[i_node];
      for (int64_t index = 0; index < 2; index++) {
        used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
                       Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c *
                       Node.apply.arguments[index].tensor_str.str_h * Node.apply.arguments[index].tensor_shape.shape_h *
                       Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w *
                       GetDataTypeSize(Node.apply.arguments[index].tensor_type);
      }
    }
  }
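  // A partition fails when the average per-device share of the sharded tensors exceeds each device's memory.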
  if (device_memory < (used_memory / num_device)) {
    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
    return FAILED;
  } else {
    return SUCCESS;
  }
}

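// Size in bytes of one element of the given tensor type; used by DevicesMemoryControl() above.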
size_t GetDataTypeSize(const TensorType &type) {
  switch (type) {
    case kInt8:
      return sizeof(int8_t);
    case kFloat16:
      return sizeof(float) / 2;
    case kFloat32:
      return sizeof(float);
    case kDouble64:
      return sizeof(double);
    default:
      MS_LOG(EXCEPTION) << "GetDataTypeSize failed: unexpected tensor type.";
  }
}
}  // namespace parallel
}  // namespace mindspore