1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/parallel/split_strategy.h"
18 #include <vector>
19 #include <unordered_map>
20 #include <string>
21 #include "nnacl/op_base.h"
22
23 namespace mindspore {
24 namespace opt {
25
ApproximateFLOPs(const std::vector<int64_t> & strides,const std::vector<int64_t> & input_shape,const std::vector<int64_t> & weight_shape)26 int64_t ApproximateFLOPs(const std::vector<int64_t> &strides, const std::vector<int64_t> &input_shape,
27 const std::vector<int64_t> &weight_shape) {
28 MS_CHECK_GT(strides.size(), 1, 0);
29 MS_CHECK_GT(input_shape.size(), kInputSize2, 0);
30 MS_CHECK_GT(weight_shape.size(), kInputSize1, 0);
31 int64_t input_h = input_shape.at(kShapeH);
32 int64_t input_w = input_shape.at(kShapeW);
33 int64_t in_c = input_shape.at(kShapeC);
34 int64_t out_c = weight_shape.at(kShapeN);
35 int64_t k_h = weight_shape.at(kShapeH);
36 int64_t k_w = weight_shape.at(kShapeW);
37 int64_t stride_h = strides.at(kIndexH);
38 int64_t stride_w = strides.at(kIndexW);
39 if (stride_h == 0 || stride_w == 0) {
40 MS_LOG(ERROR) << "divisor is zero.";
41 return 0;
42 }
43 return (input_h / stride_h) * (input_w / stride_w) * in_c * k_h * k_w * out_c / kPerFlops;
44 }
45
ParserSplitStrategy(const std::vector<int64_t> & split_ratio,const std::vector<std::string> & split_device,SplitMode split_mode)46 std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(const std::vector<int64_t> &split_ratio,
47 const std::vector<std::string> &split_device,
48 SplitMode split_mode) {
49 std::unordered_map<std::string, opt::SplitStrategy> split_strategys;
50 if (split_ratio.empty() || kSplitDefaultRatio.empty() || split_device.empty()) {
51 return split_strategys;
52 }
53 if (split_ratio.size() != kSplitDevTypes.size()) {
54 return split_strategys;
55 }
56 std::vector<std::vector<int64_t>> split_feature_map;
57 std::vector<std::vector<int64_t>> split_weight;
58 switch (split_mode) {
59 case SplitN:
60 split_feature_map = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
61 split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
62 break;
63 case SplitH:
64 split_feature_map = {kSplitDefaultRatio, split_ratio, kSplitDefaultRatio, kSplitDefaultRatio};
65 split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
66 break;
67 case SplitCIN:
68 split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
69 split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
70 break;
71 case SplitCOUT:
72 split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
73 split_weight = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
74 break;
75 default:
76 return split_strategys;
77 }
78 opt::Strategys strategys = {split_feature_map, split_weight};
79 for (const auto &supported_parallel_op : kParallelOpNames) {
80 split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), split_mode};
81 }
82
83 return split_strategys;
84 }
85 } // namespace opt
86 } // namespace mindspore
87