/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/layer_norm_info.h"
#include <algorithm>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
// LayerNorm has three outputs.
// If the shape of the input is [A, B, C, D], the shape of the first output is [A, B, C, D].
// If begin_norm_axis is 0, the shape of the second output is [1, 1, 1, 1].
// If begin_norm_axis is 1, the shape of the second output is [A, 1, 1, 1].
// If begin_norm_axis is 2, the shape of the second output is [A, B, 1, 1].
// If begin_norm_axis is 3, the shape of the second output is [A, B, C, 1].
// The shape of the third output is the same as the shape of the second output.
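// GetAttrs reads the begin_norm_axis attribute, checks that it is an int64
// value in [-rank, rank - 1], and normalizes a negative axis to its positive
// equivalent before caching it in begin_norm_axis_.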
Status LayerNormInfo::GetAttrs() {
  auto iter = attrs_.find(BEGIN_NORM_AXIS);
  if (iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Cannot find the begin_norm_axis attr";
    return FAILED;
  }
  if ((iter->second == nullptr) || !iter->second->isa<Int64Imm>()) {
    MS_LOG(ERROR) << name_ << ": The axis type is not int64_t";
    return FAILED;
  }

  int64_t dim = SizeToLong(inputs_shape_[0].size());
  auto axis = GetValue<int64_t>(iter->second);
  if ((axis >= dim) || (axis < -dim)) {
    MS_LOG(ERROR) << name_ << ": The axis (" << axis << ") is out of range [" << (-dim) << ", " << (dim - 1) << "]";
    return FAILED;
  }

  if (axis < 0) {
    axis = axis + dim;
  }
  begin_norm_axis_ = LongToSize(axis);
  return SUCCESS;
}

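// CheckStrategy validates a user-specified sharding strategy: it must cover
// all three inputs, the dimensions from begin_norm_axis_ onward must not be
// split (they are reduced over during normalization), and the gamma/beta
// strategies must match the trailing dimensions of the input strategy.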
Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  Strategys stra = strategy->GetInputDim();
  if (stra.size() != LAYER_NORM_INPUT_SIZE) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
    return FAILED;
  }

  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy value";
    return FAILED;
  }

  Dimensions input_strategy = stra[LAYER_NORM_INPUT_INDEX];
  Dimensions gamma_strategy = stra[LAYER_NORM_GAMMA_INDEX];
  Dimensions beta_strategy = stra[LAYER_NORM_BETA_INDEX];
  if (begin_norm_axis_ >= input_strategy.size()) {
    MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_;
    return FAILED;
  }
  // check input strategy: the normalized dimensions must not be split
  for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
    if (input_strategy[i] != NO_SPLIT_STRATEGY) {
      MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
      return FAILED;
    }
  }

  // check gamma and beta strategy
  if ((gamma_strategy.size() > input_strategy.size()) || (beta_strategy.size() > input_strategy.size())) {
    MS_LOG(ERROR) << name_ << ": The strategy size of gamma or beta is larger than the input strategy";
    return FAILED;
  }

  // gamma and beta are right-aligned with the input shape, so compare them
  // against the trailing elements of the input strategy
  size_t gamma_diff = input_strategy.size() - gamma_strategy.size();
  for (size_t j = 0; j < gamma_strategy.size(); ++j) {
    if (gamma_strategy[j] != input_strategy[gamma_diff + j]) {
      MS_LOG(ERROR) << name_ << ": Invalid gamma strategy " << ShapeToString(gamma_strategy);
      return FAILED;
    }
  }

  size_t beta_diff = input_strategy.size() - beta_strategy.size();
  for (size_t k = 0; k < beta_strategy.size(); ++k) {
    if (beta_strategy[k] != input_strategy[beta_diff + k]) {
      MS_LOG(ERROR) << name_ << ": Invalid beta strategy " << ShapeToString(beta_strategy);
      return FAILED;
    }
  }
  return SUCCESS;
}

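// InferDevMatrixShape derives the device matrix directly from the input's
// sharding strategy: each strategy element becomes one device-matrix axis.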
Status LayerNormInfo::InferDevMatrixShape() {
  if (strategy_ == nullptr) {
    MS_LOG(ERROR) << name_ << ": The strategy is null";
    return FAILED;
  }
  Strategys stra = strategy_->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }
  dev_matrix_shape_ = stra[0];
  return SUCCESS;
}

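// CreateInputTensorMap maps dimension i of the given input to device-matrix
// axis (rank - 1 - i), i.e. the tensor dimensions and the device-matrix axes
// are aligned in reverse order.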
Status LayerNormInfo::CreateInputTensorMap(size_t input_index) {
  if (inputs_shape_.size() <= input_index) {
    MS_LOG(ERROR) << name_ << ": Invalid index " << input_index;
    return FAILED;
  }
  Shape shape = inputs_shape_[input_index];
  Shape tensor_map;
  for (size_t i = 0; i < shape.size(); ++i) {
    tensor_map.push_back(SizeToLong(shape.size() - i - 1));
  }
  inputs_tensor_map_.push_back(tensor_map);
  return SUCCESS;
}

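// InferTensorMap builds the tensor maps for all three inputs and outputs.
// The first output shares the input's tensor map; in the second and third
// outputs the normalized dimensions are marked MAP_NONE (not mapped to any
// device axis, i.e. replicated). For example, with a rank-4 input and
// begin_norm_axis_ == 2, the input map is (3, 2, 1, 0) and the second/third
// output maps are (3, 2, MAP_NONE, MAP_NONE).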
Status LayerNormInfo::InferTensorMap() {
  if ((CreateInputTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) ||
      (CreateInputTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
      (CreateInputTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
    MS_LOG(ERROR) << name_ << ": Create input tensor map failed";
    return FAILED;
  }

  Shape first_output_tensor_map = inputs_tensor_map_[0];
  Shape second_output_tensor_map = first_output_tensor_map;
  for (size_t i = begin_norm_axis_; i < second_output_tensor_map.size(); ++i) {
    second_output_tensor_map[i] = MAP_NONE;
  }
  Shape third_output_tensor_map = second_output_tensor_map;

  outputs_tensor_map_.push_back(first_output_tensor_map);
  outputs_tensor_map_.push_back(second_output_tensor_map);
  outputs_tensor_map_.push_back(third_output_tensor_map);
  return SUCCESS;
}

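// InferAsLossDivisor computes the divisor applied when this op produces the
// loss: the number of devices holding replicated copies of output[0], derived
// from the device matrix and output[0]'s tensor map.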
Status LayerNormInfo::InferAsLossDivisor() {
  if (outputs_tensor_map_.size() != LAYER_NORM_INPUT_SIZE) {
    MS_LOG(ERROR) << name_ << ": The size of outputs tensor map " << outputs_tensor_map_.size() << " is invalid";
    return FAILED;
  }
  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
  MS_LOG(INFO) << name_ << ": The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
               << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
               << ", as_loss_divisor_ is " << as_loss_divisor_;
  return SUCCESS;
}

// Delegate cost computation to the common base implementation.
Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

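// GenerateGammaAndBetaStrategies derives gamma/beta strategies from each
// generated input strategy by dropping its leading elements, since gamma and
// beta are right-aligned with the input shape. Illustrative example: for an
// input strategy (4, 2, 1, 1) and rank-2 gamma/beta, the derived gamma/beta
// strategy is (1, 1).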
Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) {
  if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) {
    MS_LOG(ERROR) << name_ << ": The dimension of gamma or beta is larger than the input";
    return FAILED;
  }

  size_t gamma_diff = input_shape_.size() - gamma_shape_.size();
  size_t beta_diff = input_shape_.size() - beta_shape_.size();
  for (auto &sp : sp_vector) {
    if ((sp == nullptr) || sp->GetInputDim().empty()) {
      MS_LOG(ERROR) << name_ << ": Invalid strategy";
      return FAILED;
    }
    Strategys tmp_strategy;
    Dimensions input_strategy = sp->GetInputDim()[0];
    Dimensions gamma_strategy = input_strategy;
    (void)gamma_strategy.erase(gamma_strategy.begin(),
                               gamma_strategy.begin() + static_cast<different_type>(gamma_diff));
    Dimensions beta_strategy = input_strategy;
    (void)beta_strategy.erase(beta_strategy.begin(), beta_strategy.begin() + static_cast<different_type>(beta_diff));

    // reset the strategy so that it covers all three inputs
    tmp_strategy.push_back(input_strategy);
    tmp_strategy.push_back(gamma_strategy);
    tmp_strategy.push_back(beta_strategy);
    sp->ResetInputs(tmp_strategy);
  }
  return SUCCESS;
}

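// GenerateOpStrategies enumerates candidate strategies for auto-parallel:
// only the dimensions before begin_norm_axis_ may be split, and the gamma and
// beta strategies are then derived from each generated input strategy.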
std::vector<StrategyPtr> LayerNormInfo::GenerateOpStrategies(int64_t stage_id) {
  if (InitShapes() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Init shapes failed";
  }
  Shape input_split(input_shape_.size(), SPLIT_FLAG);
  if (begin_norm_axis_ >= input_split.size()) {
    MS_LOG(EXCEPTION) << name_ << ": Invalid begin norm axis " << begin_norm_axis_;
  }

  // Cannot split the dimensions from begin_norm_axis_ onward
  for (size_t i = begin_norm_axis_; i < input_split.size(); ++i) {
    input_split[i] = NO_SPLIT_FLAG;
  }

  // Generate strategies for the input
  Shapes splittable_inputs = {input_split};
  Shapes tmp_inputs_shape = {input_shape_};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Generate input strategy failed";
  }

  // Generate the strategies for gamma and beta
  if (GenerateGammaAndBetaStrategies(sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Generate gamma and beta strategies failed";
  }

  return sp_vector;
}

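// InitShapes caches the input, gamma, and beta shapes used by strategy
// checking and generation.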
Status LayerNormInfo::InitShapes() {
  if (inputs_shape_.size() != LAYER_NORM_INPUT_SIZE) {
    MS_LOG(ERROR) << name_ << ": Invalid inputs size";
    return FAILED;
  }
  input_shape_ = inputs_shape_[LAYER_NORM_INPUT_INDEX];
  gamma_shape_ = inputs_shape_[LAYER_NORM_GAMMA_INDEX];
  beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX];
  return SUCCESS;
}

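// Init and InitForCostModel are the standard operator-info entry points; both
// cache the input shapes first so that the routines above can rely on
// input_shape_, gamma_shape_, and beta_shape_.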
Status LayerNormInfo::Init(const StrategyPtr &strategy) {
  if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy) != SUCCESS)) {
    MS_LOG(ERROR) << name_ << ": Init failed";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success";
  return SUCCESS;
}

Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) {
  if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore