/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include <memory>
#include <algorithm>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "utils/trace_base.h"
#include "runtime/device/ascend/lic_manager.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kReplaceOutputIndex0 = 3;
constexpr size_t kReplaceOutputIndex1 = 4;
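// Pattern condition: true when the matched node is a constant (ValueNode).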
bool IsC(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
    MS_EXCEPTION_IF_NULL(in);
    return in->isa<ValueNode>();
  }
  return false;
}

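// Collects every direct user of the BatchNorm node into bn_outputs; raises if the node has no users.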
void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(bn);
  MS_EXCEPTION_IF_NULL(bn_outputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (manager->node_users().find(bn) == manager->node_users().end()) {
    MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should have some outputs"
                      << " trace: " << trace::DumpSourceLines(bn);
  }
  for (const auto &node_index : manager->node_users()[bn]) {
    const AnfNodePtr &output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    bn_outputs->push_back(output);
  }
}
}  // namespace

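// Reads the scalar moving-average factor out of the matched constant input; returns nullptr
// when the constant is not an fp16/fp32 tensor, which aborts the fusion.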
ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(equiv);
  auto constant_input = GetAnfNodeByVar(equiv, constant_input0_var_);
  MS_EXCEPTION_IF_NULL(constant_input);
  if (!constant_input->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = constant_input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<tensor::Tensor>()) {
    return nullptr;
  }
  auto tensor_ptr = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  if (tensor_ptr->data_type() == kNumberTypeFloat16) {
    auto *half_data = static_cast<const float16 *>(tensor_ptr->data_c());
    MS_EXCEPTION_IF_NULL(half_data);
    float float_data = half_to_float(half_data[0]);
    return MakeValue(float_data);
  } else if (tensor_ptr->data_type() == kNumberTypeFloat32) {
    auto *tensor_data = static_cast<const float *>(tensor_ptr->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    return MakeValue(tensor_data[0]);
  } else {
    MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32";
    return nullptr;
  }
}

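// Builds a BNTrainingReduce node over the BatchNorm data input; its two outputs feed the
// subsequent BNTrainingUpdate node. The output abstracts are taken from the scale/offset inputs,
// which share the per-channel shape.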
AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                        const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  // Set input to create node
  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)};
  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_reduce);
  bn_training_reduce->set_scope(node->scope());
  // Set abstract
  auto data_input1 = GetAnfNodeByVar(equiv, data_input1_var_);
  MS_EXCEPTION_IF_NULL(data_input1);
  auto data_input2 = GetAnfNodeByVar(equiv, data_input2_var_);
  MS_EXCEPTION_IF_NULL(data_input2);
  AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()};
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  bn_training_reduce->set_abstract(abstract_tuple);
  return bn_training_reduce;
}

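// Assembles the BNTrainingUpdate input list: the data input, the two BNTrainingReduce outputs,
// scale, offset, and the two moving-statistics variables.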
void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv,
                                                     const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
                                                     std::vector<AnfNodePtr> *bn_training_update_inputs) const {
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(bn_training_update_inputs);
  *bn_training_update_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)),
    utils::cast<AnfNodePtr>(GetAnfNodeByVar(equiv, data_input0_var_)),
    bn_training_reduce_outputs[0],
    bn_training_reduce_outputs[1],
    GetAnfNodeByVar(equiv, data_input1_var_),
    GetAnfNodeByVar(equiv, data_input2_var_),
    GetAnfNodeByVar(equiv, variable_input0_var_),
    GetAnfNodeByVar(equiv, variable_input1_var_),
  };
}

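// Derives the BNTrainingUpdate output abstracts from the original BatchNorm outputs plus the
// abstracts of the two moving-statistics variables.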
void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
                                                           std::vector<AbstractBasePtr> *abstract_list) const {
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(bn);
  MS_EXCEPTION_IF_NULL(abstract_list);
  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
  if (bn_abstract_tuple->elements().size() < kBnOutputNum) {
    MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is "
                      << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
  }
  auto variable_input0 = GetAnfNodeByVar(equiv, variable_input0_var_);
  auto variable_input1 = GetAnfNodeByVar(equiv, variable_input1_var_);
  MS_EXCEPTION_IF_NULL(variable_input0);
  MS_EXCEPTION_IF_NULL(variable_input1);
  *abstract_list = {bn_abstract_tuple->elements()[kIndex0], variable_input0->abstract(), variable_input1->abstract(),
                    bn_abstract_tuple->elements()[kIndex1], bn_abstract_tuple->elements()[kIndex2]};
}

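// Builds the BNTrainingUpdate node: copies epsilon from the original BatchNorm, attaches the
// scalar factor extracted by GetFactor, and marks the node as writing back to its ref inputs.
// Returns nullptr (no fusion) if the factor cannot be read.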
AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
  const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
  const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  // Set input
  std::vector<AnfNodePtr> bn_training_update_inputs;
  GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs);
  auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_update);
  // Set abstract
  AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_);
  AbstractBasePtrList abstract_list;
  GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list);
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  bn_training_update->set_abstract(abstract_tuple);
  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update);
  ValuePtr factor = GetFactor(equiv);
  if (factor == nullptr) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update);
  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
  bn_training_update->set_scope(node->scope());
  return bn_training_update;
}

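// After fusion the matched AssignSub nodes become dead; redirect any UpdateState user of
// assign_sub1 to the original monad so the side-effect chain stays intact.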
void FusedBatchNormFusion::EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_);
  MS_EXCEPTION_IF_NULL(assign_sub1);
  auto users = manager->node_users()[assign_sub1];
  for (const auto &node_index : users) {
    const AnfNodePtr &output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
      (void)manager->Replace(output, GetAnfNodeByVar(equiv, monad0_var_));
      break;
    }
  }
}

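// Rewrites the matched subgraph into BNTrainingReduce + BNTrainingUpdate: TupleGetItem users
// of the old BatchNorm at indices 3 and 4 are redirected to the corresponding BNTrainingUpdate
// outputs, and the pattern root is replaced with output 0.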
const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(node);

  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv);
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
                                 &bn_training_reduce_outputs);
  AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs);
  if (bn_training_update == nullptr) {
    MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString();
    return nullptr;
  }
  std::vector<AnfNodePtr> bn_training_update_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum,
                                 &bn_training_update_outputs);
  if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node bn must not be less than " << kBNTrainingUpdateOutputNum
                      << ", but it is " << bn_training_update_outputs.size()
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // Replace old bn outputs with new outputs
  std::vector<AnfNodePtr> bn_outputs;
  GetBNOutput(func_graph, GetAnfNodeByVar(equiv, batch_norm_var_), &bn_outputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &output : bn_outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
      continue;
    }
    auto tuple_getitem_cnode = output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
    AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(index_node);
    auto value_node = index_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto value_index = GetValue<int64_t>(value_node->value());
    if (value_index < 0) {
      MS_LOG(EXCEPTION) << "Error value index: " << value_index;
    }
    auto index = LongToSize(value_index);
    if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) {
      (void)manager->Replace(output, bn_training_update_outputs[index]);
    }
  }
  (void)manager->Replace(node, bn_training_update_outputs[0]);
  EliminateMonadNodes(func_graph, equiv);
  return nullptr;
}

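// Matched subgraph, roughly (assuming bn[i] denotes output i of BatchNorm):
//   bn = BatchNorm(x, scale, offset, ...)
//   AssignSub(var0, (var0 - bn[1]) * factor, monad)
//   AssignSub(var1, (var1 - bn[2]) * factor, monad)
// chained to bn[0] through Depend nodes; the root is the outer Depend.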
const BaseRef FusedBatchNormFusion::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  VarPtr index0 = std::make_shared<CondVar>(IsC);
  VarPtr index1 = std::make_shared<CondVar>(IsC);
  VarPtr index2 = std::make_shared<CondVar>(IsC);
  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
  VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
  VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
  VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
  VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}

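// Mixed-precision variant: the variables are cast (typically fp16 -> fp32) before the Sub,
// and the Mul result is cast back before the AssignSub.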
const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  VarPtr index0 = std::make_shared<CondVar>(IsC);
  VarPtr index1 = std::make_shared<CondVar>(IsC);
  VarPtr index2 = std::make_shared<CondVar>(IsC);
  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
  VectorRef cast2 = VectorRef({prim::kPrimCast, mul0});
  VectorRef cast3 = VectorRef({prim::kPrimCast, mul1});
  VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, cast2, monad0_var_});
  VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, cast3, monad1_var_});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}

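// Mixed-precision variant where the cast is applied to the Sub result (before the Mul)
// instead of after the Mul.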
const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  VarPtr index0 = std::make_shared<CondVar>(IsC);
  VarPtr index1 = std::make_shared<CondVar>(IsC);
  VarPtr index2 = std::make_shared<CondVar>(IsC);
  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
  VectorRef cast0 = VectorRef({prim::kPrimCast, sub0});
  VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
  VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
  VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
  VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
  VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}
}  // namespace opt
}  // namespace mindspore