1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/parallel/split_strategy.h"
19 #include <vector>
20 #include <unordered_map>
21 #include <string>
22 #include "nnacl/op_base.h"
23 #include "src/common/log_util.h"
24
25 namespace mindspore {
26 namespace opt {
27
ApproximateFLOPs(const std::vector<int64_t> & strides,const std::vector<int64_t> & input_shape,const std::vector<int64_t> & weight_shape)28 int64_t ApproximateFLOPs(const std::vector<int64_t> &strides, const std::vector<int64_t> &input_shape,
29 const std::vector<int64_t> &weight_shape) {
30 MS_CHECK_GT(strides.size(), 1, 0);
31 MS_CHECK_GT(input_shape.size(), kInputSize2, 0);
32 MS_CHECK_GT(weight_shape.size(), kInputSize1, 0);
33 int64_t input_h = input_shape.at(kShapeH);
34 int64_t input_w = input_shape.at(kShapeW);
35 int64_t in_c = input_shape.at(kShapeC);
36 int64_t out_c = weight_shape.at(kShapeN);
37 int64_t k_h = weight_shape.at(kShapeH);
38 int64_t k_w = weight_shape.at(kShapeW);
39 int64_t stride_h = strides.at(kIndexH);
40 int64_t stride_w = strides.at(kIndexW);
41 if (stride_h == 0 || stride_w == 0) {
42 MS_LOG(ERROR) << "divisor is zero.";
43 return 0;
44 }
45 return (input_h / stride_h) * (input_w / stride_w) * in_c * k_h * k_w * out_c / kPerFlops;
46 }
47
ParserSplitStrategy(const std::vector<int64_t> & split_ratio,const std::vector<std::string> & split_device,SplitMode split_mode)48 std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(const std::vector<int64_t> &split_ratio,
49 const std::vector<std::string> &split_device,
50 SplitMode split_mode) {
51 std::unordered_map<std::string, opt::SplitStrategy> split_strategys;
52 if (split_ratio.empty() || kSplitDefaultRatio.empty() || split_device.empty()) {
53 return split_strategys;
54 }
55 if (split_ratio.size() != kSplitDevTypes.size()) {
56 return split_strategys;
57 }
58 std::vector<std::vector<int64_t>> split_feature_map;
59 std::vector<std::vector<int64_t>> split_weight;
60 switch (split_mode) {
61 case SplitN:
62 split_feature_map = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
63 split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
64 break;
65 case SplitH:
66 split_feature_map = {kSplitDefaultRatio, split_ratio, kSplitDefaultRatio, kSplitDefaultRatio};
67 split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
68 break;
69 case SplitCIN:
70 split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
71 split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, split_ratio};
72 break;
73 case SplitCOUT:
74 split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
75 split_weight = {split_ratio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
76 break;
77 default:
78 return split_strategys;
79 }
80 opt::Strategies strategys = {split_feature_map, split_weight};
81 for (const auto &supported_parallel_op : kParallelOpNames) {
82 split_strategys[supported_parallel_op.second] = {strategys, kSplitDevTypes, kSplitDevTypes.size(), split_mode};
83 }
84
85 return split_strategys;
86 }
87 } // namespace opt
88 } // namespace mindspore
89