• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/dropout_do_mask_info.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 
24 #include "ir/value.h"
25 #include "pipeline/jit/resource.h"
26 #include "frontend/parallel/auto_parallel/costmodel.h"
27 #include "frontend/parallel/graph_util/node_info.h"
28 #include "frontend/parallel/step_parallel_utils.h"
29 #include "frontend/parallel/device_matrix.h"
30 #include "frontend/parallel/strategy.h"
31 
32 namespace mindspore {
33 namespace parallel {
// Counter that supplies replacement seeds for DropoutGenMask nodes whose Seed0/Seed1 are both 0 while the
// DropoutDoMask is computed repeatedly. It starts at 1 and is incremented after every use; no synchronization
// is applied to it.
static int64_t SEED_NUM{1};
35 
CheckStrategy(const StrategyPtr & strategy)36 Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
37   if (strategy == nullptr) {
38     MS_LOG(ERROR) << name_ << ": The strategy is null";
39     return FAILED;
40   }
41 
42   Strategys stra = strategy->GetInputDim();
43   if (stra.size() != 1) {
44     MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1";
45     return FAILED;
46   }
47 
48   if (inputs_shape_.empty()) {
49     MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
50     return FAILED;
51   }
52 
53   // only check the input[0]
54   Shapes input_shape = {inputs_shape_[0]};
55   return CheckStrategyValue(strategy, input_shape);
56 }
57 
InferDevMatrixShape()58 Status DropoutDoMaskInfo::InferDevMatrixShape() {
59   if (strategy_ == nullptr) {
60     MS_LOG(ERROR) << name_ << ": The strategy is null";
61     return FAILED;
62   }
63 
64   Strategys strategy = strategy_->GetInputDim();
65   if (strategy.empty()) {
66     MS_LOG(ERROR) << name_ << ": The strategy is empty";
67     return FAILED;
68   }
69 
70   dev_matrix_shape_ = strategy[0];
71   return SUCCESS;
72 }
73 
InferTensorMap()74 Status DropoutDoMaskInfo::InferTensorMap() {
75   if (inputs_shape_.empty()) {
76     MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
77     return FAILED;
78   }
79 
80   Shape tensor_map_index;
81   size_t size = inputs_shape_[0].size();
82   // if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0]
83   for (size_t i = 0; i < size; ++i) {
84     tensor_map_index.push_back(SizeToLong(size - i - 1));
85   }
86 
87   // the input[1] do not need tensor map
88   inputs_tensor_map_.push_back(tensor_map_index);   // input_0
89   outputs_tensor_map_.push_back(tensor_map_index);  // output
90   return SUCCESS;
91 }
92 
SetCostUnderStrategy(const StrategyPtr & strategy)93 Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
94   return SetCostUnderStrategyBase(strategy);
95 }
96 
GenerateOpStrategies(int64_t stage_id)97 std::vector<StrategyPtr> DropoutDoMaskInfo::GenerateOpStrategies(int64_t stage_id) {
98   if (inputs_shape_.empty()) {
99     MS_LOG(EXCEPTION) << name_ << ": The inputs shape is empty";
100   }
101 
102   Shape input0_split(inputs_shape_[0].size(), 1);
103   Shapes splittable_inputs = {input0_split};
104   Shapes used_inputs_shape = {inputs_shape_[0]};
105 
106   std::vector<StrategyPtr> sp_vector;
107   if (GenerateStrategiesForIndependentInputs(stage_id, used_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
108     MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
109   }
110   return sp_vector;
111 }
112 
GenerateBatchStrategies()113 std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
114   Dimensions strategy(inputs_shape_[0].size() - 1, 1);
115   (void)strategy.insert(strategy.begin(), stage_device_size_);
116   Strategys strategy_v = {strategy};
117   return std::make_shared<Strategys>(strategy_v);
118 }
119 
Init(const StrategyPtr & strategy)120 Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) {
121   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
122     MS_LOG(ERROR) << name_ << ": Init failed.";
123     return FAILED;
124   }
125 
126   MS_LOG(INFO) << name_ << ": Init success.";
127   return SUCCESS;
128 }
129 
InitForCostModel(const StrategyPtr & strategy)130 Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
131   if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
132     MS_LOG(ERROR) << name_ << ": Init for cost model failed";
133     return FAILED;
134   }
135 
136   MS_LOG(INFO) << name_ << ": Init for cost model success";
137   return SUCCESS;
138 }
139 
GetNonMonadInputSize(const CNodePtr & cnode)140 size_t GetNonMonadInputSize(const CNodePtr &cnode) {
141   size_t cnode_non_monad_size = cnode->size();
142   for (auto &input : cnode->inputs()) {
143     if (HasAbstractMonad(input)) {
144       cnode_non_monad_size--;
145     }
146   }
147   return cnode_non_monad_size;
148 }
149 
GetDropoutGenMaskPrim(const CNodePtr & cnode)150 PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
151   MS_EXCEPTION_IF_NULL(cnode);
152   if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
153     MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
154   }
155 
156   AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
157   MS_EXCEPTION_IF_NULL(dropout_gen_mask);
158   if (!dropout_gen_mask->isa<CNode>()) {
159     MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode";
160   }
161 
162   auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
163   size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode);
164   if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
165     MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
166   }
167   if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {
168     MS_LOG(EXCEPTION) << "The input[0] of dropout gen mask cnode is not primitive";
169   }
170 
171   ValueNodePtr value_node = dropout_gen_mask_cnode->input(0)->cast<ValueNodePtr>();
172   MS_EXCEPTION_IF_NULL(value_node);
173   PrimitivePtr prim = value_node->value()->cast<PrimitivePtr>();
174   MS_EXCEPTION_IF_NULL(prim);
175   if (prim->name() != DROPOUT_GEN_MASK) {
176     MS_LOG(EXCEPTION) << "The primitive name is not DropoutGenMask";
177   }
178   return prim;
179 }
180 
SetGenMaskShape(const CNodePtr & cnode,const Shape & input_slice_shape)181 void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
182   MS_EXCEPTION_IF_NULL(cnode);
183   if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
184     MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
185   }
186 
187   AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
188   MS_EXCEPTION_IF_NULL(dropout_gen_mask);
189   if (!dropout_gen_mask->isa<CNode>()) {
190     MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
191   }
192 
193   auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
194   size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode);
195   if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
196     MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
197   }
198 
199   if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
200     MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
201   }
202 
203   FuncGraphPtr func_graph = cnode->func_graph();
204   MS_EXCEPTION_IF_NULL(func_graph);
205   FuncGraphManagerPtr manager = func_graph->manager();
206   if (manager == nullptr) {
207     MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
208   }
209   ValuePtr new_shape = MakeValue(input_slice_shape);
210   AnfNodePtr val = NewValueNode(new_shape);
211   (void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
212 }
213 
// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutDoMask is
// split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
// and both seeds of DropoutGenMask are 0, a new DropoutGenMask is created whose two seeds are both set to the next
// value of a global seed counter.
GetDropoutGenMaskReplaceOp()218 std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp() {
219   auto cnode = cnode_;
220   std::vector<Operator> replace_ops;
221   MS_EXCEPTION_IF_NULL(cnode);
222   PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
223   MS_EXCEPTION_IF_NULL(prim);
224 
225   if (inputs_tensor_info_.empty()) {
226     MS_LOG(EXCEPTION) << "The tensor info of dropout do mask is empty";
227   }
228 
229   if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
230     MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
231   }
232 
233   if (!cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)->isa<ValueNode>()) {
234     MS_LOG(EXCEPTION) << "The keep prob of dropout do mask is not value node";
235   }
236 
237   ValuePtr keep_prob = GetValueNode(cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX));
238   MS_EXCEPTION_IF_NULL(keep_prob);
239   auto attr = prim->attrs();
240   if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
241     MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1";
242   }
243 
244   Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
245   int64_t seed_0 = GetValue<int64_t>(attr[SEED0]);
246   int64_t seed_1 = GetValue<int64_t>(attr[SEED1]);
247   if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
248     seed_0 = SEED_NUM;
249     seed_1 = SEED_NUM;
250     SEED_NUM++;
251   } else {
252     SetGenMaskShape(cnode, input_slice_shape);
253     MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
254     return replace_ops;
255   }
256   ValuePtr new_shape = MakeValue(input_slice_shape);
257   Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
258   Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
259   OperatorAttrs attrs = {attr_0, attr_1};
260   Attr param_0 = std::make_pair(SHAPE, new_shape);
261   Attr param_1 = std::make_pair(KEEP_PROB, keep_prob);
262   OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
263   OperatorArgs args = std::make_pair(attrs, params);
264   Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
265   replace_ops.push_back(replace_op);
266   return replace_ops;
267 }
268 
ReplaceOneOp(const Operator & replace_op,const CNodePtr & node)269 static void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
270   FuncGraphPtr func_graph = node->func_graph();
271   MS_EXCEPTION_IF_NULL(func_graph);
272   FuncGraphManagerPtr manager = func_graph->manager();
273   if (manager == nullptr) {
274     MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
275   }
276   std::string instance_name = CreateInstanceName(node, 0);
277   std::vector<AnfNodePtr> replace_input;
278   replace_input = ReplaceOpInput(replace_op, instance_name, node);
279   if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
280     replace_input.push_back(node->input(3));
281   }
282   CNodePtr replace_node = func_graph->NewCNode(replace_input);
283   MS_EXCEPTION_IF_NULL(replace_node);
284   ScopePtr scope = node->scope();
285   MS_EXCEPTION_IF_NULL(scope);
286   replace_node->set_scope(scope);
287   replace_node->set_in_forward_flag(true);
288   replace_input[0]->set_scope(scope);
289   PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
290   PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
291   SetUserAttrs(origin_prim->attrs(), prim);
292   (void)manager->Replace(node, replace_node);
293 }
294 
ReplaceNodeInputOrAttrs()295 void DropoutDoMaskInfo::ReplaceNodeInputOrAttrs() {
296   auto cnode = cnode_;
297   MS_EXCEPTION_IF_NULL(cnode);
298   std::vector<Operator> replace_op = GetDropoutGenMaskReplaceOp();
299   if (replace_op.empty()) {
300     MS_LOG(DEBUG) << name_ << ": No need to replace dropout_gen_mask";
301     return;
302   }
303   if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
304     MS_LOG(EXCEPTION) << name_ << ": The size of drop out do mask cnode's input is not "
305                       << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
306   }
307   ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
308 }
309 }  // namespace parallel
310 }  // namespace mindspore
311