• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/optimizer/parallel/split_strategy.h"
18 #include <vector>
19 #include <unordered_map>
20 #include <string>
21 #include "nnacl/op_base.h"
22 
23 namespace mindspore {
24 namespace opt {
25 
ApproximateFLOPs(const std::vector<int64_t> & strides,const std::vector<int64_t> & input_shape,const std::vector<int64_t> & weight_shape)26 int64_t ApproximateFLOPs(const std::vector<int64_t> &strides, const std::vector<int64_t> &input_shape,
27                          const std::vector<int64_t> &weight_shape) {
28   MS_CHECK_GT(strides.size(), 1, 0);
29   MS_CHECK_GT(input_shape.size(), kInputSize2, 0);
30   MS_CHECK_GT(weight_shape.size(), kInputSize1, 0);
31   int64_t input_h = input_shape.at(kShapeH);
32   int64_t input_w = input_shape.at(kShapeW);
33   int64_t in_c = input_shape.at(kShapeC);
34   int64_t out_c = weight_shape.at(kShapeN);
35   int64_t k_h = weight_shape.at(kShapeH);
36   int64_t k_w = weight_shape.at(kShapeW);
37   int64_t stride_h = strides.at(kIndexH);
38   int64_t stride_w = strides.at(kIndexW);
39   if (stride_h == 0 || stride_w == 0) {
40     MS_LOG(ERROR) << "divisor is zero.";
41     return 0;
42   }
43   return (input_h / stride_h) * (input_w / stride_w) * in_c * k_h * k_w * out_c / kPerFlops;
44 }
45 
ParserSplitStrategy(const std::vector<int64_t> & split_ratio,const std::vector<std::string> & split_device,SplitMode split_mode)46 std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(const std::vector<int64_t> &split_ratio,
47                                                                         const std::vector<std::string> &split_device,
48                                                                         SplitMode split_mode) {
49   std::unordered_map<std::string, opt::SplitStrategy> split_strategys;
50   if (split_ratio.empty() || kSplitDefaultRatio.empty() || split_device.empty()) {
51     return split_strategys;
52   }
53   if (split_ratio.size() != kSplitDevTypes.size()) {
54     return split_strategys;
55   }
56   std::vector<std::vector<int64_t>> split_feature_map;
57   std::vector<std::vector<int64_t>> split_weight;
58   switch (split_mode) {
59     case SplitN:
60       split_feature_map = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
61       split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
62       break;
63     case SplitH:
64       split_feature_map = {kSplitDefaultRatio, split_ratio, kSplitDefaultRatio, kSplitDefaultRatio};
65       split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
66       break;
67     case SplitCIN:
68       split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
69       split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
70       break;
71     case SplitCOUT:
72       split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
73       split_weight = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
74       break;
75     default:
76       return split_strategys;
77   }
78   opt::Strategys strategys = {split_feature_map, split_weight};
79   for (const auto &supported_parallel_op : kParallelOpNames) {
80     split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), split_mode};
81   }
82 
83   return split_strategys;
84 }
85 }  // namespace opt
86 }  // namespace mindspore
87