1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/arithmetic_info.h"
18
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22
23 #include "frontend/parallel/device_matrix.h"
24 #include "frontend/parallel/strategy.h"
25 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
26
27 namespace mindspore {
28 namespace parallel {
// Left-pads `smaller_size_shape` with 1s until it has as many dimensions as
// `bigger_size_shape`, so the two shapes line up on their trailing dimensions.
//
// @param bigger_size_shape   shape whose rank is the target rank.
// @param smaller_size_shape  shape to pad (taken by value and returned padded).
// @return `smaller_size_shape` with (rank(bigger) - rank(smaller)) leading 1s;
//         returned unchanged if it already has at least as many dimensions.
Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) {
  // Guard: without this, the size_t subtraction below wraps around when the
  // "smaller" shape is actually longer, producing a near-infinite insert count.
  if (smaller_size_shape.size() >= bigger_size_shape.size()) {
    return smaller_size_shape;
  }
  const size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size();
  // One bulk insert instead of repeated front-inserts (which shift every element each time).
  (void)smaller_size_shape.insert(smaller_size_shape.begin(), insert_num, 1);
  return smaller_size_shape;
}
36
InferExpendShape()37 Shapes ArithmeticBase::InferExpendShape() {
38 Shape input_a_shape = inputs_shape_.at(0);
39 Shape input_b_shape = inputs_shape_.at(1);
40 Shapes input_shapes;
41 size_t input_a_size = input_a_shape.size();
42 size_t input_b_size = input_b_shape.size();
43 if (input_a_size > input_b_size) {
44 input_shapes.push_back(input_a_shape);
45 input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape));
46 } else if (input_a_size < input_b_size) {
47 input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape));
48 input_shapes.push_back(input_b_shape);
49 } else {
50 input_shapes.push_back(input_a_shape);
51 input_shapes.push_back(input_b_shape);
52 }
53 return input_shapes;
54 }
55
ExpendStrategy(const StrategyPtr & strategy)56 Strategys ExpendStrategy(const StrategyPtr &strategy) {
57 Strategys expend_strategy;
58 Strategys stra = strategy->GetInputDim();
59 Dimensions sub_a_strategy = stra.at(0);
60 Dimensions sub_b_strategy = stra.at(1);
61 size_t input_a_size = sub_a_strategy.size();
62 size_t input_b_size = sub_b_strategy.size();
63 if (input_a_size > input_b_size) {
64 expend_strategy.push_back(sub_a_strategy);
65 expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy));
66 } else if (input_a_size < input_b_size) {
67 expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy));
68 expend_strategy.push_back(sub_b_strategy);
69 } else {
70 expend_strategy = stra;
71 }
72 return expend_strategy;
73 }
74
CheckStrategy(const StrategyPtr & strategy)75 Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
76 if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
77 MS_LOG(ERROR) << name_ << " : Invalid strategy.";
78 return FAILED;
79 }
80 Shapes input_shapes = InferExpendShape();
81 Strategys expend_strategy = ExpendStrategy(strategy);
82 Dimensions sub_a_strategy = expend_strategy.at(0);
83 Dimensions sub_b_strategy = expend_strategy.at(1);
84 Shape input_a_shape = input_shapes.at(0);
85 Shape input_b_shape = input_shapes.at(1);
86
87 for (size_t i = 0; i < input_a_shape.size(); ++i) {
88 if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != 1) && (input_b_shape[i] != 1)) {
89 MS_LOG(ERROR) << name_ << " : Invalid strategy.";
90 return FAILED;
91 }
92 }
93 return SUCCESS;
94 }
95
InferDevMatrixShape()96 Status ArithmeticBase::InferDevMatrixShape() {
97 Strategys expend_strategy = ExpendStrategy(strategy_);
98 Dimensions sub_a_strategy = expend_strategy.at(0);
99 Dimensions sub_b_strategy = expend_strategy.at(1);
100 Shape dev_shape;
101 for (size_t i = 0; i < sub_a_strategy.size(); ++i) {
102 if (sub_a_strategy[i] != sub_b_strategy[i]) {
103 dev_shape.push_back(sub_a_strategy[i] * sub_b_strategy[i]);
104 } else {
105 dev_shape.push_back(sub_a_strategy[i]);
106 }
107 }
108 dev_matrix_shape_ = dev_shape;
109
110 return SUCCESS;
111 }
112
SetExpendTensorMap(const Shape & strategy,const Shape & dev_matrix_shape)113 TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) {
114 TensorMap tensor_map_index;
115 for (size_t i = 0; i < strategy.size(); ++i) {
116 if (strategy[i] == dev_matrix_shape[i]) {
117 tensor_map_index.push_back((int64_t)(LAST_INDEX(strategy.size()) - i));
118 } else {
119 tensor_map_index.push_back(-1);
120 }
121 }
122 return tensor_map_index;
123 }
124
SetTensorMap(const Shape & strategy_expend,const Shape & dev_matrix_shape,const Shape & strategy)125 TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) {
126 TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape);
127 size_t dev_matrix_size = dev_matrix_shape.size();
128 size_t strategy_size = strategy.size();
129 if (dev_matrix_size != strategy_size) {
130 (void)expend_map.erase(expend_map.begin(),
131 expend_map.begin() + static_cast<different_type>(dev_matrix_size - strategy_size));
132 }
133 return expend_map;
134 }
135
ReComputeBatchSplitFlagList()136 void ArithmeticBase::ReComputeBatchSplitFlagList() {
137 Shapes expend_shapes = InferExpendShape();
138 Shape expend_a_shape = expend_shapes.at(0);
139 Shape expend_b_shape = expend_shapes.at(1);
140 if (expend_a_shape.size() != expend_b_shape.size()) {
141 MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong.";
142 }
143 if (expend_a_shape.empty()) {
144 split_flag_list_[0] = false;
145 split_flag_list_[1] = false;
146 return;
147 }
148 (expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false);
149 (expend_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false);
150 }
151
InferTensorMap()152 Status ArithmeticBase::InferTensorMap() {
153 Shape tensor_map_index;
154 Strategys expend_strategy = ExpendStrategy(strategy_);
155 Dimensions sub_a_expend_strategy = expend_strategy.at(0);
156 Dimensions sub_b_expend_strategy = expend_strategy.at(1);
157 Strategys stra = strategy_->GetInputDim();
158 Dimensions sub_a_strategy = stra.at(0);
159 Dimensions sub_b_strategy = stra.at(1);
160 for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
161 tensor_map_index.push_back((int64_t)(LAST_INDEX(sub_a_expend_strategy.size()) - i));
162 }
163
164 Shape dev_shape;
165 for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
166 if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) {
167 dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]);
168 } else {
169 dev_shape.push_back(sub_a_expend_strategy[i]);
170 }
171 }
172 inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy));
173 inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy));
174 outputs_tensor_map_.push_back(tensor_map_index);
175
176 return SUCCESS;
177 }
178
SetCostUnderStrategy(const StrategyPtr & strategy)179 Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
180
GenerateOpStrategies(int64_t stage_id)181 std::vector<StrategyPtr> ArithmeticBase::GenerateOpStrategies(int64_t stage_id) {
182 Shape input0_split(inputs_shape_[0].size(), 1);
183 Shape input1_split(inputs_shape_[1].size(), 1);
184 Shapes splittable_inputs = {input0_split, input1_split};
185
186 std::vector<StrategyPtr> sp_vector;
187 if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
188 MS_LOG(EXCEPTION) << name_ << " : Generate strategies with broadcast failed.";
189 }
190 MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success.";
191
192 return sp_vector;
193 }
194
Init(const StrategyPtr & strategy)195 Status ArithmeticBase::Init(const StrategyPtr &strategy) {
196 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
197 MS_LOG(ERROR) << name_ << " : Init failed.";
198 return FAILED;
199 }
200 MS_LOG(INFO) << name_ << " : Init success.";
201 return SUCCESS;
202 }
203
InitForCostModel(const StrategyPtr & strategy)204 Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) {
205 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
206 MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
207 return FAILED;
208 }
209
210 MS_LOG(INFO) << name_ << " : Init for cost model success.";
211 return SUCCESS;
212 }
213 } // namespace parallel
214 } // namespace mindspore
215