/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include <memory>
#include <algorithm>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "utils/trace_base.h"
#include "runtime/device/ascend/lic_manager.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kReplaceOutputIndex0 = 3;
constexpr size_t kReplaceOutputIndex1 = 4;
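// Matches pattern nodes that are compile-time constants, i.e. ValueNodes.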
bool IsC(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
    MS_EXCEPTION_IF_NULL(in);
    return in->isa<ValueNode>();
  }
  return false;
}

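// Collects every user of the given BatchNorm node so the fusion can rewire
// them to the outputs of the fused node later on.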
void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(bn);
  MS_EXCEPTION_IF_NULL(bn_outputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (manager->node_users().find(bn) == manager->node_users().end()) {
    MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should have some outputs"
                      << " trace: " << trace::DumpSourceLines(bn);
  }
  for (const auto &node_index : manager->node_users()[bn]) {
    const AnfNodePtr &output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    bn_outputs->push_back(output);
  }
}
}  // namespace

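// Extracts the scalar factor from the constant input of the matched Mul node
// (the coefficient of the moving-statistics update). Only fp16 and fp32
// constants are supported; any other type makes the fusion bail out by
// returning nullptr.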
ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(equiv);
  auto constant_input = GetAnfNodeByVar(equiv, constant_input0_var_);
  MS_EXCEPTION_IF_NULL(constant_input);
  if (!constant_input->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = constant_input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<tensor::Tensor>()) {
    return nullptr;
  }
  auto tensor_ptr = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  if (tensor_ptr->data_type() == kNumberTypeFloat16) {
    auto *half_data = static_cast<const float16 *>(tensor_ptr->data_c());
    MS_EXCEPTION_IF_NULL(half_data);
    float float_data = half_to_float(half_data[0]);
    return MakeValue(float_data);
  } else if (tensor_ptr->data_type() == kNumberTypeFloat32) {
    auto *tensor_data = static_cast<const float *>(tensor_ptr->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    return MakeValue(tensor_data[0]);
  } else {
    MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32";
    return nullptr;
  }
}

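// Builds a BNTrainingReduce node over the data input; its two outputs feed
// the BNTrainingUpdate node created afterwards, and their abstracts are taken
// from the second and third inputs of the matched BatchNorm.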
AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                        const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  // Set input to create node
  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)};
  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_reduce);
  bn_training_reduce->set_scope(node->scope());
  // Set abstract
  auto data_input1 = GetAnfNodeByVar(equiv, data_input1_var_);
  MS_EXCEPTION_IF_NULL(data_input1);
  auto data_input2 = GetAnfNodeByVar(equiv, data_input2_var_);
  MS_EXCEPTION_IF_NULL(data_input2);
  AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()};
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  bn_training_reduce->set_abstract(abstract_tuple);
  return bn_training_reduce;
}

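// Assembles the input list of BNTrainingUpdate: the original data input, the
// two BNTrainingReduce outputs, the remaining BatchNorm inputs, and the two
// variables updated by the matched AssignSub nodes.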
void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv,
                                                     const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
                                                     std::vector<AnfNodePtr> *bn_training_update_inputs) const {
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(bn_training_update_inputs);
  *bn_training_update_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)),
    utils::cast<AnfNodePtr>(GetAnfNodeByVar(equiv, data_input0_var_)),
    bn_training_reduce_outputs[0],
    bn_training_reduce_outputs[1],
    GetAnfNodeByVar(equiv, data_input1_var_),
    GetAnfNodeByVar(equiv, data_input2_var_),
    GetAnfNodeByVar(equiv, variable_input0_var_),
    GetAnfNodeByVar(equiv, variable_input1_var_),
  };
}

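// Derives the output abstracts of BNTrainingUpdate from the abstract tuple of
// the original BatchNorm and the abstracts of the two variable inputs.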
void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
                                                           std::vector<AbstractBasePtr> *abstract_list) const {
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(bn);
  MS_EXCEPTION_IF_NULL(abstract_list);
  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
  if (bn_abstract_tuple->elements().size() < kBnOutputNum) {
    MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is "
                      << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
  }
  auto variable_input0 = GetAnfNodeByVar(equiv, variable_input0_var_);
  auto variable_input1 = GetAnfNodeByVar(equiv, variable_input1_var_);
  MS_EXCEPTION_IF_NULL(variable_input0);
  MS_EXCEPTION_IF_NULL(variable_input1);
  *abstract_list = {bn_abstract_tuple->elements()[kIndex0], variable_input0->abstract(), variable_input1->abstract(),
                    bn_abstract_tuple->elements()[kIndex1], bn_abstract_tuple->elements()[kIndex2]};
}

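// Creates the BNTrainingUpdate node, copies the epsilon attribute from the
// original BatchNorm, and attaches the factor read from the pattern. Returns
// nullptr (skipping the fusion) when no usable factor is found.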
AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
  const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
  const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  // Set input
  std::vector<AnfNodePtr> bn_training_update_inputs;
  GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs);
  auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_update);
  // Set abstract
  AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_);
  AbstractBasePtrList abstract_list;
  GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list);
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  bn_training_update->set_abstract(abstract_tuple);
  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update);
  ValuePtr factor = GetFactor(equiv);
  if (factor == nullptr) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update);
  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
  bn_training_update->set_scope(node->scope());
  return bn_training_update;
}

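// The AssignSub nodes become dead after the fusion; redirect any UpdateState
// user of assign_sub1 to the original monad so the monad chain stays intact.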
void FusedBatchNormFusion::EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_);
  MS_EXCEPTION_IF_NULL(assign_sub1);
  auto users = manager->node_users()[assign_sub1];
  for (const auto &node_index : users) {
    const AnfNodePtr &output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
      (void)manager->Replace(output, GetAnfNodeByVar(equiv, monad0_var_));
      break;
    }
  }
}

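// Replaces the matched subgraph with BNTrainingReduce + BNTrainingUpdate:
// outputs 3 and 4 of the old BatchNorm are rewired to the corresponding
// BNTrainingUpdate outputs, the root node is replaced by output 0, and the
// leftover monad nodes are cleaned up.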
const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(node);

  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv);
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
                                 &bn_training_reduce_outputs);
  AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs);
  if (bn_training_update == nullptr) {
    MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString();
    return nullptr;
  }
  std::vector<AnfNodePtr> bn_training_update_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum,
                                 &bn_training_update_outputs);
  if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of BNTrainingUpdate must be " << kBNTrainingUpdateOutputNum
                      << ", but it is " << bn_training_update_outputs.size()
                      << " trace: " << trace::DumpSourceLines(node);
  }
  // Replace old bn outputs with new outputs
  std::vector<AnfNodePtr> bn_outputs;
  GetBNOutput(func_graph, GetAnfNodeByVar(equiv, batch_norm_var_), &bn_outputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &output : bn_outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
      continue;
    }
    auto tuple_getitem_cnode = output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
    AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(index_node);
    auto value_node = index_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto value_index = GetValue<int64_t>(value_node->value());
    if (value_index < 0) {
      MS_LOG(EXCEPTION) << "Error value index: " << value_index;
    }
    auto index = LongToSize(value_index);
    if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) {
      (void)manager->Replace(output, bn_training_update_outputs[index]);
    }
  }
  (void)manager->Replace(node, bn_training_update_outputs[0]);
  EliminateMonadNodes(func_graph, equiv);
  return nullptr;
}

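// Matches the unfused training-time BatchNorm followed by the moving-average
// updates of the two statistics variables:
//   AssignSub(var, (var - TupleGetItem(BatchNorm, i)) * factor)
// with both AssignSub branches tied to TupleGetItem(BatchNorm, 0) via Depend.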
const BaseRef FusedBatchNormFusion::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  VarPtr index0 = std::make_shared<CondVar>(IsC);
  VarPtr index1 = std::make_shared<CondVar>(IsC);
  VarPtr index2 = std::make_shared<CondVar>(IsC);
  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
  VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
  VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
  VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
  VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}

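// Mix-precision variant 0: the variables are cast before the Sub, and the Mul
// results are cast back before feeding the AssignSub nodes.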
const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  VarPtr index0 = std::make_shared<CondVar>(IsC);
  VarPtr index1 = std::make_shared<CondVar>(IsC);
  VarPtr index2 = std::make_shared<CondVar>(IsC);
  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
  VectorRef cast2 = VectorRef({prim::kPrimCast, mul0});
  VectorRef cast3 = VectorRef({prim::kPrimCast, mul1});
  VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, cast2, monad0_var_});
  VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, cast3, monad1_var_});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}

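// Mix-precision variant 1: like variant 0, but the cast back to the variable
// type happens on the Sub results, before the Mul.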
const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  VarPtr index0 = std::make_shared<CondVar>(IsC);
  VarPtr index1 = std::make_shared<CondVar>(IsC);
  VarPtr index2 = std::make_shared<CondVar>(IsC);
  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
  VectorRef cast0 = VectorRef({prim::kPrimCast, sub0});
  VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
  VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
  VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
  VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
  VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}
}  // namespace opt
}  // namespace mindspore