1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/dropout_do_mask_info.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23
24 #include "ir/value.h"
25 #include "pipeline/jit/resource.h"
26 #include "frontend/parallel/auto_parallel/costmodel.h"
27 #include "frontend/parallel/graph_util/node_info.h"
28 #include "frontend/parallel/step_parallel_utils.h"
29 #include "frontend/parallel/device_matrix.h"
30 #include "frontend/parallel/strategy.h"
31
32 namespace mindspore {
33 namespace parallel {
// File-local seed counter. GetDropoutGenMaskReplaceOp() assigns its current value to both
// seeds of a replaced DropoutGenMask and then increments it.
static int64_t SEED_NUM{1};
35
CheckStrategy(const StrategyPtr & strategy)36 Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
37 if (strategy == nullptr) {
38 MS_LOG(ERROR) << name_ << ": The strategy is null";
39 return FAILED;
40 }
41
42 Strategys stra = strategy->GetInputDim();
43 if (stra.size() != 1) {
44 MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1";
45 return FAILED;
46 }
47
48 if (inputs_shape_.empty()) {
49 MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
50 return FAILED;
51 }
52
53 // only check the input[0]
54 Shapes input_shape = {inputs_shape_[0]};
55 return CheckStrategyValue(strategy, input_shape);
56 }
57
InferDevMatrixShape()58 Status DropoutDoMaskInfo::InferDevMatrixShape() {
59 if (strategy_ == nullptr) {
60 MS_LOG(ERROR) << name_ << ": The strategy is null";
61 return FAILED;
62 }
63
64 Strategys strategy = strategy_->GetInputDim();
65 if (strategy.empty()) {
66 MS_LOG(ERROR) << name_ << ": The strategy is empty";
67 return FAILED;
68 }
69
70 dev_matrix_shape_ = strategy[0];
71 return SUCCESS;
72 }
73
InferTensorMap()74 Status DropoutDoMaskInfo::InferTensorMap() {
75 if (inputs_shape_.empty()) {
76 MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
77 return FAILED;
78 }
79
80 Shape tensor_map_index;
81 size_t size = inputs_shape_[0].size();
82 // if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0]
83 for (size_t i = 0; i < size; ++i) {
84 tensor_map_index.push_back(SizeToLong(size - i - 1));
85 }
86
87 // the input[1] do not need tensor map
88 inputs_tensor_map_.push_back(tensor_map_index); // input_0
89 outputs_tensor_map_.push_back(tensor_map_index); // output
90 return SUCCESS;
91 }
92
SetCostUnderStrategy(const StrategyPtr & strategy)93 Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
94 return SetCostUnderStrategyBase(strategy);
95 }
96
GenerateOpStrategies(int64_t stage_id)97 std::vector<StrategyPtr> DropoutDoMaskInfo::GenerateOpStrategies(int64_t stage_id) {
98 if (inputs_shape_.empty()) {
99 MS_LOG(EXCEPTION) << name_ << ": The inputs shape is empty";
100 }
101
102 Shape input0_split(inputs_shape_[0].size(), 1);
103 Shapes splittable_inputs = {input0_split};
104 Shapes used_inputs_shape = {inputs_shape_[0]};
105
106 std::vector<StrategyPtr> sp_vector;
107 if (GenerateStrategiesForIndependentInputs(stage_id, used_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
108 MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
109 }
110 return sp_vector;
111 }
112
GenerateBatchStrategies()113 std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
114 Dimensions strategy(inputs_shape_[0].size() - 1, 1);
115 (void)strategy.insert(strategy.begin(), stage_device_size_);
116 Strategys strategy_v = {strategy};
117 return std::make_shared<Strategys>(strategy_v);
118 }
119
Init(const StrategyPtr & strategy)120 Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) {
121 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
122 MS_LOG(ERROR) << name_ << ": Init failed.";
123 return FAILED;
124 }
125
126 MS_LOG(INFO) << name_ << ": Init success.";
127 return SUCCESS;
128 }
129
InitForCostModel(const StrategyPtr & strategy)130 Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
131 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
132 MS_LOG(ERROR) << name_ << ": Init for cost model failed";
133 return FAILED;
134 }
135
136 MS_LOG(INFO) << name_ << ": Init for cost model success";
137 return SUCCESS;
138 }
139
GetNonMonadInputSize(const CNodePtr & cnode)140 size_t GetNonMonadInputSize(const CNodePtr &cnode) {
141 size_t cnode_non_monad_size = cnode->size();
142 for (auto &input : cnode->inputs()) {
143 if (HasAbstractMonad(input)) {
144 cnode_non_monad_size--;
145 }
146 }
147 return cnode_non_monad_size;
148 }
149
GetDropoutGenMaskPrim(const CNodePtr & cnode)150 PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
151 MS_EXCEPTION_IF_NULL(cnode);
152 if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
153 MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
154 }
155
156 AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
157 MS_EXCEPTION_IF_NULL(dropout_gen_mask);
158 if (!dropout_gen_mask->isa<CNode>()) {
159 MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode";
160 }
161
162 auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
163 size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode);
164 if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
165 MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
166 }
167 if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {
168 MS_LOG(EXCEPTION) << "The input[0] of dropout gen mask cnode is not primitive";
169 }
170
171 ValueNodePtr value_node = dropout_gen_mask_cnode->input(0)->cast<ValueNodePtr>();
172 MS_EXCEPTION_IF_NULL(value_node);
173 PrimitivePtr prim = value_node->value()->cast<PrimitivePtr>();
174 MS_EXCEPTION_IF_NULL(prim);
175 if (prim->name() != DROPOUT_GEN_MASK) {
176 MS_LOG(EXCEPTION) << "The primitive name is not DropoutGenMask";
177 }
178 return prim;
179 }
180
SetGenMaskShape(const CNodePtr & cnode,const Shape & input_slice_shape)181 void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
182 MS_EXCEPTION_IF_NULL(cnode);
183 if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
184 MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
185 }
186
187 AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
188 MS_EXCEPTION_IF_NULL(dropout_gen_mask);
189 if (!dropout_gen_mask->isa<CNode>()) {
190 MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
191 }
192
193 auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
194 size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode);
195 if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
196 MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
197 }
198
199 if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
200 MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
201 }
202
203 FuncGraphPtr func_graph = cnode->func_graph();
204 MS_EXCEPTION_IF_NULL(func_graph);
205 FuncGraphManagerPtr manager = func_graph->manager();
206 if (manager == nullptr) {
207 MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
208 }
209 ValuePtr new_shape = MakeValue(input_slice_shape);
210 AnfNodePtr val = NewValueNode(new_shape);
211 (void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
212 }
213
214 // DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
215 // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
216 // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
217 // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
GetDropoutGenMaskReplaceOp()218 std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp() {
219 auto cnode = cnode_;
220 std::vector<Operator> replace_ops;
221 MS_EXCEPTION_IF_NULL(cnode);
222 PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
223 MS_EXCEPTION_IF_NULL(prim);
224
225 if (inputs_tensor_info_.empty()) {
226 MS_LOG(EXCEPTION) << "The tensor info of dropout do mask is empty";
227 }
228
229 if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
230 MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
231 }
232
233 if (!cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)->isa<ValueNode>()) {
234 MS_LOG(EXCEPTION) << "The keep prob of dropout do mask is not value node";
235 }
236
237 ValuePtr keep_prob = GetValueNode(cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX));
238 MS_EXCEPTION_IF_NULL(keep_prob);
239 auto attr = prim->attrs();
240 if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
241 MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1";
242 }
243
244 Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
245 int64_t seed_0 = GetValue<int64_t>(attr[SEED0]);
246 int64_t seed_1 = GetValue<int64_t>(attr[SEED1]);
247 if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
248 seed_0 = SEED_NUM;
249 seed_1 = SEED_NUM;
250 SEED_NUM++;
251 } else {
252 SetGenMaskShape(cnode, input_slice_shape);
253 MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
254 return replace_ops;
255 }
256 ValuePtr new_shape = MakeValue(input_slice_shape);
257 Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
258 Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
259 OperatorAttrs attrs = {attr_0, attr_1};
260 Attr param_0 = std::make_pair(SHAPE, new_shape);
261 Attr param_1 = std::make_pair(KEEP_PROB, keep_prob);
262 OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
263 OperatorArgs args = std::make_pair(attrs, params);
264 Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
265 replace_ops.push_back(replace_op);
266 return replace_ops;
267 }
268
ReplaceOneOp(const Operator & replace_op,const CNodePtr & node)269 static void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
270 FuncGraphPtr func_graph = node->func_graph();
271 MS_EXCEPTION_IF_NULL(func_graph);
272 FuncGraphManagerPtr manager = func_graph->manager();
273 if (manager == nullptr) {
274 MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
275 }
276 std::string instance_name = CreateInstanceName(node, 0);
277 std::vector<AnfNodePtr> replace_input;
278 replace_input = ReplaceOpInput(replace_op, instance_name, node);
279 if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
280 replace_input.push_back(node->input(3));
281 }
282 CNodePtr replace_node = func_graph->NewCNode(replace_input);
283 MS_EXCEPTION_IF_NULL(replace_node);
284 ScopePtr scope = node->scope();
285 MS_EXCEPTION_IF_NULL(scope);
286 replace_node->set_scope(scope);
287 replace_node->set_in_forward_flag(true);
288 replace_input[0]->set_scope(scope);
289 PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
290 PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
291 SetUserAttrs(origin_prim->attrs(), prim);
292 (void)manager->Replace(node, replace_node);
293 }
294
ReplaceNodeInputOrAttrs()295 void DropoutDoMaskInfo::ReplaceNodeInputOrAttrs() {
296 auto cnode = cnode_;
297 MS_EXCEPTION_IF_NULL(cnode);
298 std::vector<Operator> replace_op = GetDropoutGenMaskReplaceOp();
299 if (replace_op.empty()) {
300 MS_LOG(DEBUG) << name_ << ": No need to replace dropout_gen_mask";
301 return;
302 }
303 if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
304 MS_LOG(EXCEPTION) << name_ << ": The size of drop out do mask cnode's input is not "
305 << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
306 }
307 ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
308 }
309 } // namespace parallel
310 } // namespace mindspore
311