/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <iostream>

#include "frontend/parallel/status.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/auto_parallel/stage_compute.h"
#include "include/common/utils/parallel_context.h"

namespace mindspore {
namespace parallel {
// Get the target node's weight for sorting.
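// The weight is the cost estimated by the operator's cost model; heavier nodes are processed first when partitioning.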
double GetWeights(const Graph::NodeType &node) {
  const OperatorRec &op = node.apply;

  if (op.op_type == OperatorType::kRecMatMul) {
    // For MatMul
    auto cost_ptr = std::make_shared<CostMatMul>();

    return cost_ptr->GetMaxCostIn(op);
  } else if (op.op_type == OperatorType::kRecBatchMatMul) {
    // For BatchMatMul
    auto cost_ptr = std::make_shared<CostBatchMatMul>();

    return cost_ptr->GetMaxCostIn(node);
  } else if (op.op_type == OperatorType::kRecConvolution) {
    // For Convolution
    auto cost_ptr = std::make_shared<CostConvolution>();

    return cost_ptr->GetMinCostIn(node);
  } else if (op.op_type == OperatorType::kRecPooling) {
    // For Pooling
    auto cost_ptr = std::make_shared<CostPooling>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecElmWiseOp) {
    // For TensorAdd
    auto cost_ptr = std::make_shared<CostTensorAdd>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecReLU) {
    // For Activation
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecReshape) {
    // For Reshape
    auto cost_ptr = std::make_shared<CostReshape>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecBiasAdd) {
    // For BiasAdd
    auto cost_ptr = std::make_shared<CostBiasAdd>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp ||
             op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub ||
             op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv ||
             op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) {
    // For element-wise op
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetMinCostIn();
  } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecUnsortedSegmentOp ||
             op.op_type == OperatorType::kRecSoftmax || op.op_type == OperatorType::kRecBatchParallel ||
             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
             op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
    // For BatchParallel op
    auto cost_ptr = std::make_shared<CostBatchParallel>();

    return cost_ptr->GetMaxCostIn();
  } else if (op.op_type == OperatorType::kRecUnknownType) {
    // For Unknown type
    return 0.0;
  } else if (op.op_type == OperatorType::kRecVirtual) {
    // For Virtual type
    return 0.0;
  } else if (op.op_type == OperatorType::kRecStandAlone) {
    // For StandAlone type
    return 0.0;
  } else {
    MS_LOG(EXCEPTION) << "Failure: GetWeights failed: unknown operator type.";
  }
}

// Sort all the nodes by their weights
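// Only kApplication nodes are considered; the result lists their indices in descending order of weight.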
std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  std::vector<std::pair<double, size_t>> weight_to_node_index;
  std::vector<size_t> node_index_by_weights;

  // Get node's weight.
  for (size_t pos = 0; pos < graph->nodes.size(); pos++) {
    if (graph->nodes[pos].info == kApplication) {
      const Graph::NodeType &node_ptr = graph->nodes[pos];
      double weight;
      bool mem_first = false;
      if (g_device_manager->DeviceNum() > SIZE_THIRTY_TWO && graph->micro_batch_size < INT64_EIGHT) {
        mem_first = true;
      }
      if (PARTITION_ORDER == PartitionOrder::TopologyOrder && !mem_first) {
        weight = (node_ptr.apply.op_type == OperatorType::kRecUnknownType) ? DOUBLE_LOWEST : pos;
      } else {
        weight = GetWeights(node_ptr);
      }
      size_t index = pos;
      weight_to_node_index.push_back(std::make_pair(weight, index));
    }
  }

  // Order the ops (i.e. the nodes of the graph) by weight.
  std::sort(weight_to_node_index.begin(), weight_to_node_index.end());

  // Store the result in node_index_by_weights.
  uint64_t size = weight_to_node_index.size();
  for (uint64_t i = 1; i <= size; i++) {
    node_index_by_weights.push_back(weight_to_node_index[size - i].second);
  }

  return node_index_by_weights;
}

// Get optimal strategy to partition the target node
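// Each operator type delegates to its cost model, which proposes the best 2-part cut given the strategies
// already chosen for other nodes. Under dyn_shape_tmp_fix, a few parameters are given hard-coded strategies.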
StrategyRec PartitionNode(Graph::NodeType node,
                          const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                          const std::shared_ptr<Graph> &graph, const bool isTraining) {
  bool enable_conv_chw_partition = false;
  MS_EXCEPTION_IF_NULL(graph);

  if (node.apply.op_type == OperatorType::kRecMatMul) {
    if (graph->dyn_shape_tmp_fix) {
      if (node.param_name.find(".projection.weight") != std::string::npos) {
        node.apply.str.inputTensor[0].str_w /= SIZE_TWO;
        node.apply.str.inputTensor[1].str_h /= SIZE_TWO;
        return node.apply.str;
      }
      if (node.param_name.find(".mapping.weight") != std::string::npos) {
        node.apply.str.inputTensor[1].str_w /= SIZE_TWO;
        node.apply.str.outputTensor.str_w /= SIZE_TWO;
        return node.apply.str;
      }
      if (node.param_name.find(".attention.dense2.weight") != std::string::npos) {
        node.apply.str.inputTensor[1].str_w /= SIZE_TWO;
        node.apply.str.outputTensor.str_w /= SIZE_TWO;
        return node.apply.str;
      }
      if (node.param_name.find(".attention_norm.weight") != std::string::npos) {
        node.apply.str.inputTensor[1].str_w /= SIZE_TWO;
        node.apply.str.outputTensor.str_w /= SIZE_TWO;
        return node.apply.str;
      }
      if (node.param_name.find(".norm_out.weight") != std::string::npos) {
        return node.apply.str;
      }
    }

    // For MatMul
    auto cost_ptr = std::make_shared<CostMatMul>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, isTraining);
  } else if (node.apply.op_type == OperatorType::kRecBatchMatMul) {
    if (graph->dyn_shape_tmp_fix) {
      if (node.param_name.find(".projection.weight") != std::string::npos) {
        node.apply.str.inputTensor[0].str_w /= SIZE_TWO;
        node.apply.str.inputTensor[1].str_h /= SIZE_TWO;
        return node.apply.str;
      }
      if (node.param_name.find(".mapping.weight") != std::string::npos) {
        node.apply.str.inputTensor[1].str_w /= SIZE_TWO;
        node.apply.str.outputTensor.str_w /= SIZE_TWO;
        return node.apply.str;
      }

      bool same_inputs = false;
      bool projection_bias_bmm = false;
      bool mapping_bias_bmm = false;
      for (size_t idx = 0; idx < node.node_in.size(); idx++) {
        if (idx == node.node_in.size() - 1) {
          break;
        }
        for (size_t idx_bis = idx + 1; idx_bis < node.node_in.size(); idx_bis++) {
          if (node.node_in[idx] == node.node_in[idx_bis]) {
            same_inputs = true;
            break;
          }
        }
        if (same_inputs) {
          break;
        }
      }
      if (same_inputs) {
        return node.apply.str;
      }

      for (size_t idx = 0; idx < node.node_in.size(); idx++) {
        auto incoming_node_idx = node.node_in[idx];
        if (graph->nodes[incoming_node_idx].param_name.find(".projection.bias") != std::string::npos) {
          projection_bias_bmm = true;
          break;
        }
        if (graph->nodes[incoming_node_idx].param_name.find(".mapping.bias") != std::string::npos) {
          mapping_bias_bmm = true;
          break;
        }
      }
      if (projection_bias_bmm) {
        node.apply.str.inputTensor[0].str_w /= SIZE_TWO;
        node.apply.str.inputTensor[1].str_h /= SIZE_TWO;
        return node.apply.str;
      }
      if (mapping_bias_bmm) {
        node.apply.str.inputTensor[1].str_w /= SIZE_TWO;
        node.apply.str.outputTensor.str_w /= SIZE_TWO;
        return node.apply.str;
      }
    }

    // For BatchMatMul
    auto cost_ptr = std::make_shared<CostBatchMatMul>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, isTraining);
  } else if (node.apply.op_type == OperatorType::kRecConvolution) {
    // For Convolution
    auto cost_ptr = std::make_shared<CostConvolution>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, enable_conv_chw_partition);
  } else if (node.apply.op_type == OperatorType::kRecPooling) {
    // For Pooling
    auto cost_ptr = std::make_shared<CostPooling>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecElmWiseOp) {
    // For TensorAdd
    auto cost_ptr = std::make_shared<CostTensorAdd>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecReLU) {
    // For Activation
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecReshape) {
    // For Reshape
    auto cost_ptr = std::make_shared<CostReshape>();

    return cost_ptr->GetOptimalStr(node);
  } else if (node.apply.op_type == OperatorType::kRecBiasAdd) {
    // For BiasAdd
    auto cost_ptr = std::make_shared<CostBiasAdd>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp ||
             node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub ||
             node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv ||
             node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) {
    // For element-wise op
    auto cost_ptr = std::make_shared<CostCommon>();

    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
  } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
             node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == OperatorType::kRecSoftmax ||
             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
             node.apply.op_type == OperatorType::kRecUnsortedSegmentOp ||
             node.apply.op_type == OperatorType::kRecBatchParallel || node.apply.op_type == OperatorType::kRecVirtual) {
    // For BatchParallel type
    auto cost_ptr = std::make_shared<CostBatchParallel>();
    return cost_ptr->GetOptimalStr(node);
  } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
    // For SoftmaxCrossEntropyWithLogits type
    auto cost_ptr = std::make_shared<CostSoftmaxCrossEntropyWithLogits>();
    return cost_ptr->GetOptimalStr(node);
  } else if (node.apply.op_type == OperatorType::kRecUnknownType) {
    // For Unknown type
    StrategyRec default_strategy;
    return default_strategy;
  } else if (node.apply.op_type == OperatorType::kRecStandAlone) {
    // For StandAlone type
    StrategyRec default_strategy;
    return default_strategy;
  } else {
    MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
  }
}

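// Express the strategy chosen in this partition step relative to the previous one: each stride of new_str is
// divided by the corresponding stride of old_str (skipped when an old stride is ~0 to avoid dividing by zero).
// For example, if an input's old stride was (1, 1, 1, 0.5) and the new stride is (1, 1, 1, 0.25), the one-loop
// strategy for that input is (1, 1, 1, 0.5), i.e. only the width dimension was cut in half in this step.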
StrategyRec GetOneLoopStrategy(size_t op_inputs_num, const StrategyRec &old_str, StrategyRec new_str) {
  for (size_t i = 0; i < op_inputs_num; i++) {
    if (std::abs(old_str.inputTensor[i].str_n) > EPS && std::abs(old_str.inputTensor[i].str_c) > EPS &&
        std::abs(old_str.inputTensor[i].str_h) > EPS && std::abs(old_str.inputTensor[i].str_w) > EPS) {
      new_str.inputTensor[i].str_n = new_str.inputTensor[i].str_n / old_str.inputTensor[i].str_n;
      new_str.inputTensor[i].str_c = new_str.inputTensor[i].str_c / old_str.inputTensor[i].str_c;
      new_str.inputTensor[i].str_h = new_str.inputTensor[i].str_h / old_str.inputTensor[i].str_h;
      new_str.inputTensor[i].str_w = new_str.inputTensor[i].str_w / old_str.inputTensor[i].str_w;
    }
  }

  if (old_str.outputTensor.str_n > EPS && old_str.outputTensor.str_c > EPS && old_str.outputTensor.str_h > EPS &&
      old_str.outputTensor.str_w > EPS) {
    new_str.outputTensor.str_n = new_str.outputTensor.str_n / old_str.outputTensor.str_n;
    new_str.outputTensor.str_c = new_str.outputTensor.str_c / old_str.outputTensor.str_c;
    new_str.outputTensor.str_h = new_str.outputTensor.str_h / old_str.outputTensor.str_h;
    new_str.outputTensor.str_w = new_str.outputTensor.str_w / old_str.outputTensor.str_w;
  }

  return new_str;
}

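// Re-apply the n_cut-th recorded strategy to the node and propagate it to the node's tensors.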
Graph::NodeType ChangeStrategy(Graph::NodeType Node, size_t n_cut) {
  if (n_cut >= Node.apply.strs.size()) {
    MS_LOG(EXCEPTION) << "Strategy not available";
  }
  Node.apply.str = Node.apply.strs[n_cut];
  Node = ApplyStrToTensor(Node);

  return Node;
}

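// Number of strategies recorded for this node (the initial strategy plus one per partition step).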
size_t GetStratNumber(const Graph::NodeType &Node) { return Node.apply.strs.size(); }

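// When the pipeline algorithm is disabled, roll every node's strategy back by log2(stage_num) partition levels.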
void PartitionPipelineStages(double device_memory, const std::shared_ptr<Graph> &graph) {
  if (!ENABLE_PIPE_ALGO) {
    size_t n_stage = LongToSize(parallel::ParallelContext::GetInstance()->pipeline_stage_split_num());
    size_t n_node = graph->nodes.size();
    size_t roll_back = FloatToSize(log2(n_stage));

    MS_LOG(INFO) << "ROLLING BACK ACCORDING TO STAGE NUMBER (" << n_stage << ") BY " << roll_back << " LEVELS"
                 << std::endl;
    for (size_t i_node = 0; i_node < n_node; ++i_node) {
      Graph::NodeType &node_ptr = graph->nodes[i_node];
      size_t n_cut = GetStratNumber(graph->nodes[i_node]) - roll_back - 1;
      graph->nodes[i_node] = ChangeStrategy(node_ptr, n_cut);
    }
  }
}

// Partition graph into all devices.
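// Performs log2(num_device) rounds of binary (2-part) partitioning: in each round the nodes are visited in the
// order produced by SortByWeight, a strategy is chosen per node via PartitionNode, and the per-round strategies
// are recorded. Afterwards the stage number may be adjusted (auto pipeline), pipeline stages are rolled back if
// needed, and the estimated memory usage is checked against the available device memory.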
Status PartitionForAllDevices(size_t num_device, double device_memory, const std::shared_ptr<Graph> &graph,
                              bool isTraining, const FuncGraphPtr &root) {
  if (num_device < 1) {
    MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
  }

  if (num_device > 1024) {
    MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be larger than 1024.";
  }

  MS_EXCEPTION_IF_NULL(graph);

  // Compute iter times
  int64_t iter_times = static_cast<int64_t>(log2(num_device));
  if (iter_times > 10) {
    MS_LOG(EXCEPTION) << "ERROR: Number of iter_times can't be larger than 10.";
  }

  // N-cuts loop
  for (int64_t loop = 0; loop < iter_times; loop++) {
    // Sort by weights
    std::vector<size_t> reorder_node_list = SortByWeight(graph);

    // get total node number
    size_t iter_nodes = reorder_node_list.size();

    // Temp vector to map node name to its strategy.
    std::vector<std::pair<std::string, StrategyRec>> node_name_to_strategy;

    // Loop for all the nodes
    for (size_t i_node = 0; i_node < iter_nodes; i_node++) {
      // get current node's index
      size_t index = reorder_node_list[i_node];

      Graph::NodeType &node_ptr = graph->nodes[index];

      // 2-parts partitioning StrategyRec of the last loop
      StrategyRec old_str = graph->nodes[index].apply.str;

      // Save first strategy too
      if (graph->nodes[index].apply.strs.size() == 0) {
        graph->nodes[index].apply.strs.push_back(old_str);
      }

      MS_LOG(INFO) << "------------Node_name: " << graph->nodes[index].name << " -------------";

      // Search the optimal strategy to cut this operator and store it in the graph.
      graph->nodes[index].apply.str = PartitionNode(node_ptr, node_name_to_strategy, graph, isTraining);
      graph->nodes[index].apply.strs.push_back(graph->nodes[index].apply.str);

      // Get current 2-parts partitioning strategy of this loop
      size_t op_inputs_num = graph->nodes[index].node_in.size();
      StrategyRec one_loop_strategyrec = GetOneLoopStrategy(op_inputs_num, old_str, graph->nodes[index].apply.str);

      // Apply OP Strategy to Tensor Strategy.
      graph->nodes[index] = ApplyStrToTensor(node_ptr);

      // Note down the node name and its strategy in this loop.
      auto node_name_to_str = std::pair<std::string, StrategyRec>(graph->nodes[index].name, one_loop_strategyrec);
      node_name_to_strategy.push_back(node_name_to_str);
    }
  }

  // Auto pipeline
  size_t new_stage_num = ParallelSuggestion(root, graph);
  if (parallel::ParallelContext::GetInstance()->auto_pipeline()) {
    ChangeStageNumber(root, new_stage_num);
  }

  // Partition stages
  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
    PartitionPipelineStages(device_memory, graph);
  }

  DevicesMemoryControl(num_device, device_memory, graph);
  return SUCCESS;
}

// Apply OP Strategy to Tensor Strategy
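// Copy the node's chosen output strategy to its output tensor and the input strategies to its two arguments.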
Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
  // Set Node's tensor_parm
  Node.tensor_parm.tensor_str.str_n = Node.apply.str.outputTensor.str_n;
  Node.tensor_parm.tensor_str.str_c = Node.apply.str.outputTensor.str_c;
  Node.tensor_parm.tensor_str.str_h = Node.apply.str.outputTensor.str_h;
  Node.tensor_parm.tensor_str.str_w = Node.apply.str.outputTensor.str_w;

  // Set input tensors' tensor_parm
  for (int64_t i = 0; i < 2; i++) {
    Node.apply.arguments[i].tensor_str.str_n = Node.apply.str.inputTensor[i].str_n;
    Node.apply.arguments[i].tensor_str.str_c = Node.apply.str.inputTensor[i].str_c;
    Node.apply.arguments[i].tensor_str.str_h = Node.apply.str.inputTensor[i].str_h;
    Node.apply.arguments[i].tensor_str.str_w = Node.apply.str.inputTensor[i].str_w;
  }
  return Node;
}

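// Sum the estimated size of every application node's (partitioned) input tensors and warn when the per-device
// share exceeds device_memory.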
void DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (num_device == 0) {
    MS_LOG(EXCEPTION) << "Failure: device number is 0.";
  }

  uint64_t iter_nodes = graph->nodes.size();
  double used_memory = 0.0;

  for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
    if (graph->nodes[i_node].info == InfoType::kApplication) {
      Graph::NodeType &Node = graph->nodes[i_node];
      for (int64_t index = 0; index < 2; index++) {
        used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
                       Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c *
                       Node.apply.arguments[index].tensor_str.str_h * Node.apply.arguments[index].tensor_shape.shape_h *
                       Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w *
                       GetDataTypeSize(Node.apply.arguments[index].tensor_type);
      }
    }
  }

  if (device_memory < (used_memory / num_device)) {
    MS_LOG(WARNING) << "It is estimated that the task may fail due to running out of device memory!";
  }
}

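// Element size (in bytes) used for the memory estimation above.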
size_t GetDataTypeSize(const TensorType &type) {
  switch (type) {
    case kInt8:
      return sizeof(int64_t);
    case kFloat16:
      return sizeof(float) / 2;
    case kFloat32:
      return sizeof(float);
    case kDouble64:
      return sizeof(double);
    default:
      MS_LOG(EXCEPTION) << "GetDataTypeSize Failed. Unexpected type";
  }
}
}  // namespace parallel
}  // namespace mindspore