/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/bn_split.h"

#include <vector>
#include <memory>
#include <string>
#include <limits>

#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/trace_base.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kReduceOpSum = "sum";
constexpr auto kDeviceNum = "device_num";
// Length of the "-op" marker that precedes the op id in fullname_with_scope.
constexpr size_t kPositionOffset = 3;
// User-defined AllReduce fusion ids start from this value.
constexpr int64_t kFusionNumThreshold = 2;

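// Builds a BNTrainingReduce node that consumes BatchNorm's data input (input 1) and produces the
// per-channel statistics (two fp32 tensors of shape [C]; sum and square_sum on Ascend) consumed by
// BNTrainingUpdate. Returns false when the BatchNorm node does not match the expected input count
// or shape, so the caller can skip the split.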
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
                                     std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_cnode);
  if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
    MS_LOG(INFO) << "BatchNorm's input size is not equal to " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
    return false;
  }
  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
  bn_training_reduce_inputs.push_back(bn_cnode->input(kIndex1));
  auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_reduce);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  bn_training_reduce->set_kernel_info(kernel_info);
  std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
  if (bn_shape_i0.size() < kShape2dDims) {
    MS_LOG(INFO) << "The dimension of BatchNorm's first input is less than " << kShape2dDims;
    return false;
  }
  std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[kDim1]};
  auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
  auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get());
  bn_training_reduce->set_scope(bn_cnode->scope());
  AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce);

  CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs);
  return true;
}

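// Builds the BNTrainingUpdate node. Its inputs are, in order: BatchNorm's data input, the two
// outputs of BNTrainingReduce, and BatchNorm's scale, offset, mean and variance inputs. The new
// node reuses BatchNorm's abstract, so its outputs can replace BatchNorm's outputs directly.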
AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
                                           const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_cnode);
  CheckCNodeInputSize(bn_cnode, kBnInputTensorNum);
  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
    MS_LOG(EXCEPTION) << "BNTrainingReduce outputs have wrong size"
                      << " trace: " << trace::DumpSourceLines(bn_cnode);
  }
  // The inputs of BNTrainingUpdate come from the outputs of BNTrainingReduce and the inputs of BatchNorm.
  std::vector<AnfNodePtr> bn_training_update_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName))};
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex1));
  bn_training_update_inputs.push_back(bn_training_reduce_outputs[kIndex0]);
  bn_training_update_inputs.push_back(bn_training_reduce_outputs[kIndex1]);
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex2));
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex3));
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex4));
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex5));
  auto bn_training_update = graph->NewCNode(bn_training_update_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_update);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  bn_training_update->set_kernel_info(kernel_info);
  bn_training_update->set_abstract(bn_cnode->abstract());
  bn_training_update->set_scope(bn_cnode->scope());
  auto factor = AnfAlgo::GetNodeAttr<float>(bn_cnode, kAttrMomentum);
  AnfAlgo::SetNodeAttr(kAttrFactor, MakeValue<float>(factor), bn_training_update);
  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update);
  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
  return bn_training_update;
}

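// Splits a training-mode BatchNorm into BNTrainingReduce + BNTrainingUpdate for the TBE backend.
// Returns the BNTrainingUpdate node on success, or nullptr when the node does not qualify so the
// pass leaves it unchanged.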
AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) {
    MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has fewer than " << kBnInputTensorNum << " inputs.";
    return nullptr;
  }
  // Create BNTrainingReduce node and get outputs of BNTrainingReduce
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
    MS_LOG(WARNING) << "Failed to create BNTrainingReduce, quit split";
    return nullptr;
  }
  if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "Failed to create outputs of BNTrainingReduce"
                      << " trace: " << trace::DumpSourceLines(node);
  }

  // Create BNTrainingUpdate node
  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
}

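// Splits SyncBatchNorm the same way as BatchNorm, but inserts AllReduce + Mul between
// BNTrainingReduce and BNTrainingUpdate so the per-channel statistics are averaged across all
// devices before the update step.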
AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) {
    MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has fewer than " << kBnInputTensorNum << " inputs.";
    return nullptr;
  }
  // Create BNTrainingReduce node and get outputs of BNTrainingReduce
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
    MS_LOG(WARNING) << "Failed to create BNTrainingReduce, quit split";
    return nullptr;
  }
  if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "Failed to create outputs of BNTrainingReduce"
                      << " trace: " << trace::DumpSourceLines(node);
  }

  // Average the local statistics across devices: AllReduce(sum) followed by Mul(1 / device_num).
  std::vector<AnfNodePtr> allreduce_mul_outputs;
  for (size_t i = 0; i < bn_training_reduce_outputs.size(); ++i) {
    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode);
    allreduce_mul_outputs.emplace_back(allreduce_mul_output);
  }

  // Create BNTrainingUpdate node
  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, allreduce_mul_outputs);
}
}  // namespace

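// Creates a scalar fp32 ValueNode holding 1 / device_num. Multiplying the AllReduce'd (summed)
// statistics by this reciprocal turns the global sum into a global mean; e.g. with
// device_num = 8 the node holds 0.125.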
AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);
  if (!AnfAlgo::HasNodeAttr(kDeviceNum, sync_bn_cnode)) {
    MS_LOG(EXCEPTION) << "The node [" << sync_bn_cnode->DebugString() << "] does not have attr device_num.";
  }
  auto device_num = AnfAlgo::GetNodeAttr<int64_t>(sync_bn_cnode, kDeviceNum);
  if (device_num == 0) {
    MS_LOG(EXCEPTION) << "The device_num attr of node [" << sync_bn_cnode->DebugString() << "] should not be 0";
  }
  MS_LOG(INFO) << "device_num value: " << device_num;
  const float device_num_reciprocal = 1.0 / device_num;

  std::vector<int64_t> device_num_shape = {};
  auto device_num_reciprocal_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, device_num_shape);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_tensor);
  auto data_ptr = device_num_reciprocal_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  auto *val = reinterpret_cast<float *>(data_ptr);
  *val = device_num_reciprocal;

  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, device_num_shape);
  auto device_num_reciprocal_value = kernel_graph->NewValueNode(abstract, device_num_reciprocal_tensor);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_value);
  kernel_graph->AddValueNodeToGraph(device_num_reciprocal_value);
  return device_num_reciprocal_value;
}

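// Inserts a Cast to dst_type in front of the consumer when the input's inferred type differs;
// otherwise returns the input unchanged. The output shape is copied from the input, so only the
// element type changes.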
AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input);
  if (AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
    AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
    AnfAlgo::SetOutputInferTypeAndShape({dst_type}, {AnfAlgo::GetOutputInferShape(input, 0)}, cast.get());
    AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
    cast->set_scope(input->scope());
    return cast;
  }
  return input;
}

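// Builds the cross-device averaging subgraph for SyncBatchNorm:
//   input -> (Cast to fp32 if needed) -> AllReduce(sum) -> Mul(1 / device_num)
//         -> (Cast back to the original dtype if needed)
// which computes the mean of the local statistics over all devices in the group.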
AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
                                 const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(allreduce_input);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);

  // Cast the input to fp32 to reduce the number of Cast nodes. Since BNTrainingReduce/BNTrainingUpdateGrad
  // only support fp32 output, a fp16 inferred output would otherwise produce the chain:
  // output_fp32 -> cast_fp16 -> allreduce & mul -> cast_fp32. Adding this cast eliminates those casts.
  // Should be removed once BNTrainingReduce/BNTrainingUpdateGrad support fp16 output.
  AnfNodePtr input_node = InsertCast(graph, allreduce_input, kNumberTypeFloat32);

  // create AllReduce
  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), input_node};
  auto allreduce = graph->NewCNode(allreduce_inputs);
  MS_EXCEPTION_IF_NULL(allreduce);
  allreduce->set_abstract(input_node->abstract());
  allreduce->set_scope(allreduce_input->scope());
  AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce);
  AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce);
  // use SyncBatchNorm's opid as AllReduce's fusion attr
  auto sync_bn_opname = sync_bn_cnode->fullname_with_scope();
  auto opid_pos = sync_bn_opname.rfind("-op");
  if (opid_pos == std::string::npos || opid_pos + kPositionOffset >= sync_bn_opname.size()) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has no opid.";
    return nullptr;
  }
  int64_t opid = std::stol(sync_bn_opname.substr(opid_pos + kPositionOffset));
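  // e.g. a node named ".../SyncBatchNorm-op42" yields opid 42, which is used as-is; opid 0 would
  // be remapped below to INT64_MAX - 2 and opid 1 to INT64_MAX - 1.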
  // User-defined AllReduce fusion ids start at kFusionNumThreshold; remap smaller opids to the
  // top of the int64 range so they cannot collide with user-defined fusion ids.
  if (opid < kFusionNumThreshold) {
    opid = opid - kFusionNumThreshold + std::numeric_limits<int64_t>::max();
  }
  AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(opid), allreduce);

  // create Mul
  auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode);
  std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), allreduce,
                                        device_num_reciprocal_vnode};
  auto mul = graph->NewCNode(mul_inputs);
  MS_EXCEPTION_IF_NULL(mul);
  mul->set_abstract(input_node->abstract());
  mul->set_scope(allreduce_input->scope());

  // Cast the output back to the original data type to reduce the number of Cast nodes.
  // Should be removed once BNTrainingReduce/BNTrainingUpdateGrad support fp16 output.
  return InsertCast(graph, mul, AnfAlgo::GetOutputInferDataType(allreduce_input, 0));
}

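// Matches any BatchNorm CNode via a variadic SeqVar; the actual input count is validated inside
// SplitBatchNormForTBE, and only training-mode BatchNorm is split.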
const BaseRef BnSplit::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
  return VectorRef({prim::kPrimBatchNorm, Xs});
}

const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  if (!GetBoolAttr(node, kAttrIsTraining)) {
    MS_LOG(INFO) << "The is_training attr should be true to do the split";
    return nullptr;
  }
  return SplitBatchNormForTBE(func_graph, node);
}

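// SyncBatchNorm is matched with the same variadic pattern; unlike BnSplit::Process, no
// is_training check is applied here, as SyncBatchNorm is assumed to occur only in training graphs.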
const BaseRef SyncBnSplit::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimSyncBatchNorm, Xs});
}

const AnfNodePtr SyncBnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  return SyncBNSplitForTBE(func_graph, node);
}
}  // namespace opt
}  // namespace mindspore