• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/parallel/split_strategy.h"
19 #include <vector>
20 #include <unordered_map>
21 #include <string>
22 #include "nnacl/op_base.h"
23 #include "src/common/log_util.h"
24 
25 namespace mindspore {
26 namespace opt {
27 
ApproximateFLOPs(const std::vector<int64_t> & strides,const std::vector<int64_t> & input_shape,const std::vector<int64_t> & weight_shape)28 int64_t ApproximateFLOPs(const std::vector<int64_t> &strides, const std::vector<int64_t> &input_shape,
29                          const std::vector<int64_t> &weight_shape) {
30   MS_CHECK_GT(strides.size(), 1, 0);
31   MS_CHECK_GT(input_shape.size(), kInputSize2, 0);
32   MS_CHECK_GT(weight_shape.size(), kInputSize1, 0);
33   int64_t input_h = input_shape.at(kShapeH);
34   int64_t input_w = input_shape.at(kShapeW);
35   int64_t in_c = input_shape.at(kShapeC);
36   int64_t out_c = weight_shape.at(kShapeN);
37   int64_t k_h = weight_shape.at(kShapeH);
38   int64_t k_w = weight_shape.at(kShapeW);
39   int64_t stride_h = strides.at(kIndexH);
40   int64_t stride_w = strides.at(kIndexW);
41   if (stride_h == 0 || stride_w == 0) {
42     MS_LOG(ERROR) << "divisor is zero.";
43     return 0;
44   }
45   return (input_h / stride_h) * (input_w / stride_w) * in_c * k_h * k_w * out_c / kPerFlops;
46 }
47 
ParserSplitStrategy(const std::vector<int64_t> & split_ratio,const std::vector<std::string> & split_device,SplitMode split_mode)48 std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(const std::vector<int64_t> &split_ratio,
49                                                                         const std::vector<std::string> &split_device,
50                                                                         SplitMode split_mode) {
51   std::unordered_map<std::string, opt::SplitStrategy> split_strategys;
52   if (split_ratio.empty() || kSplitDefaultRatio.empty() || split_device.empty()) {
53     return split_strategys;
54   }
55   if (split_ratio.size() != kSplitDevTypes.size()) {
56     return split_strategys;
57   }
58   std::vector<std::vector<int64_t>> split_feature_map;
59   std::vector<std::vector<int64_t>> split_weight;
60   switch (split_mode) {
61     case SplitN:
62       split_feature_map = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
63       split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
64       break;
65     case SplitH:
66       split_feature_map = {kSplitDefaultRatio, split_ratio, kSplitDefaultRatio, kSplitDefaultRatio};
67       split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
68       break;
69     case SplitCIN:
70       split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
71       split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
72       break;
73     case SplitCOUT:
74       split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
75       split_weight = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
76       break;
77     default:
78       return split_strategys;
79   }
80   opt::Strategies strategys = {split_feature_map, split_weight};
81   for (const auto &supported_parallel_op : kParallelOpNames) {
82     split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), split_mode};
83   }
84 
85   return split_strategys;
86 }
87 }  // namespace opt
88 }  // namespace mindspore
89