/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/bn_split.h"

#include <vector>
#include <memory>
#include <string>
#include <limits>

#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/trace_base.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kReduceOpSum = "sum";
constexpr auto kDeviceNum = "device_num";
// Length of the "-op" marker that precedes the op id in a node's full name.
constexpr size_t kPositionOffset = 3;
// Generated fusion ids below this threshold would collide with user-defined fusion values,
// so they are remapped in CreateAllReduceAndMul.
constexpr int64_t kFusionNumThreshold = 2;

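// Build a BNTrainingReduce node that takes BatchNorm's data input (input 1) and produces the
// per-channel statistics consumed by BNTrainingUpdate. On Ascend the training-mode BatchNorm is
// split into this reduce phase plus an update phase; the two float32 outputs created here are
// assumed to be the channel-wise sum and square-sum over the input.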
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
                                     std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_cnode);
  if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
    MS_LOG(INFO) << "BatchNorm's input size is not equal to " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
    return false;
  }
  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
  bn_training_reduce_inputs.push_back(bn_cnode->input(kIndex1));
  auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_reduce);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  bn_training_reduce->set_kernel_info(kernel_info);
  std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
  if (bn_shape_i0.size() < kShape2dDims) {
    MS_LOG(INFO) << "The dims of BatchNorm's first input's shape are less than " << kShape2dDims;
    return false;
  }
  // Both outputs are float32 and sized by input dim 1 (the channel dimension for NCHW layout).
  std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[kDim1]};
  auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
  auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get());
  bn_training_reduce->set_scope(bn_cnode->scope());
  AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce);

  CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs);
  return true;
}

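// Build the BNTrainingUpdate node from the two BNTrainingReduce outputs plus BatchNorm's original
// inputs (x and, in the standard BatchNorm signature, scale, offset, mean and variance). The new
// node inherits BatchNorm's abstract and scope, and BatchNorm's momentum attr is carried over as
// BNTrainingUpdate's factor.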
AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
                                           const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(bn_cnode);
  CheckCNodeInputSize(bn_cnode, kBnInputTensorNum);
  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
    MS_LOG(EXCEPTION) << "BNTrainingReduce has wrong output size"
                      << " trace: " << trace::DumpSourceLines(bn_cnode);
  }
  // The inputs of BNTrainingUpdate are the outputs of BNTrainingReduce plus the inputs of BN.
  std::vector<AnfNodePtr> bn_training_update_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName))};
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex1));
  bn_training_update_inputs.push_back(bn_training_reduce_outputs[kIndex0]);
  bn_training_update_inputs.push_back(bn_training_reduce_outputs[kIndex1]);
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex2));
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex3));
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex4));
  bn_training_update_inputs.push_back(bn_cnode->input(kIndex5));
  auto bn_training_update = graph->NewCNode(bn_training_update_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_update);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  bn_training_update->set_kernel_info(kernel_info);
  bn_training_update->set_abstract(bn_cnode->abstract());
  bn_training_update->set_scope(bn_cnode->scope());
  auto factor = AnfAlgo::GetNodeAttr<float>(bn_cnode, kAttrMomentum);
  AnfAlgo::SetNodeAttr(kAttrFactor, MakeValue<float>(factor), bn_training_update);
  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update);
  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
  return bn_training_update;
}

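// Split a training-mode BatchNorm into BNTrainingReduce followed by BNTrainingUpdate. Returns the
// BNTrainingUpdate node that replaces the original BatchNorm, or nullptr when the node does not
// qualify for the split.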
AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) {
    MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has fewer than " << kBnInputTensorNum << " inputs.";
    return nullptr;
  }
  // Create BNTrainingReduce node and get the outputs of BNTrainingReduce
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
    MS_LOG(WARNING) << "Create BNTrainingReduce failed, quit split";
    return nullptr;
  }
  if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "Make outputs of op BNTrainingReduce failed"
                      << " trace: " << trace::DumpSourceLines(node);
  }

  // Create BNTrainingUpdate node
  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
}

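// Split SyncBatchNorm the same way as BatchNorm, but average the BNTrainingReduce statistics
// across devices first: each reduce output goes through AllReduce(sum) followed by a multiply
// with 1/device_num before it is fed into BNTrainingUpdate.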
AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) {
    MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has fewer than " << kBnInputTensorNum << " inputs.";
    return nullptr;
  }
  // Create BNTrainingReduce node and get the outputs of BNTrainingReduce
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
    MS_LOG(WARNING) << "Create BNTrainingReduce failed, quit split";
    return nullptr;
  }
  if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "Make outputs of op BNTrainingReduce failed"
                      << " trace: " << trace::DumpSourceLines(node);
  }

  // Average every reduce output across devices before the update phase.
  std::vector<AnfNodePtr> allreduce_mul_outputs;
  for (size_t i = 0; i < bn_training_reduce_outputs.size(); ++i) {
    auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode);
    allreduce_mul_outputs.emplace_back(allreduce_mul_output);
  }

  // Create BNTrainingUpdate node
  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, allreduce_mul_outputs);
}
}  // namespace

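// Create a float32 scalar ValueNode holding 1/device_num, read from SyncBatchNorm's device_num
// attr. For example, with device_num = 8 the node holds 0.125, so AllReduce(sum) followed by this
// multiply yields the mean across devices.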
AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);
  if (!AnfAlgo::HasNodeAttr(kDeviceNum, sync_bn_cnode)) {
    MS_LOG(EXCEPTION) << "The node [" << sync_bn_cnode->DebugString() << "] does not have attr device_num.";
  }
  auto device_num = AnfAlgo::GetNodeAttr<int64_t>(sync_bn_cnode, kDeviceNum);
  if (device_num == 0) {
    MS_LOG(EXCEPTION) << "The device_num attr of node [" << sync_bn_cnode->DebugString() << "] should not be 0";
  }
  MS_LOG(INFO) << "device_num value: " << device_num;
  const float device_num_reciprocal = 1.0 / device_num;

  std::vector<int64_t> device_num_shape = {};
  auto device_num_reciprocal_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, device_num_shape);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_tensor);
  auto data_ptr = device_num_reciprocal_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  auto *val = reinterpret_cast<float *>(data_ptr);
  *val = device_num_reciprocal;

  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, device_num_shape);
  auto device_num_reciprocal_value = kernel_graph->NewValueNode(abstract, device_num_reciprocal_tensor);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_value);
  kernel_graph->AddValueNodeToGraph(device_num_reciprocal_value);
  return device_num_reciprocal_value;
}

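// Insert a Cast to dst_type in front of input, or return input unchanged when its inferred output
// type already matches.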
AnfNodePtr InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input);
  if (AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) {
    AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(kCastOpName)), input});
    AnfAlgo::SetOutputInferTypeAndShape({dst_type}, {AnfAlgo::GetOutputInferShape(input, 0)}, cast.get());
    AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
    cast->set_scope(input->scope());
    return cast;
  }
  return input;
}

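// Wrap one BNTrainingReduce output in AllReduce(sum) and a Mul by 1/device_num, i.e. build the
// cross-device mean of that statistic. The result keeps the data type of allreduce_input.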
AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
                                 const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(allreduce_input);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);

  // Cast the input to fp32 to reduce the number of cast nodes. The BNTrainingReduce and
  // BNTrainingUpdateGrad ops only support fp32 output, so when the inferred output is fp16 the
  // backend would otherwise insert output_fp32 -> cast_fp16 -> allreduce&mul -> cast_fp32; casting
  // here eliminates that chain. This should be removed once BNTrainingReduce/BNTrainingUpdateGrad
  // support fp16 output.
  AnfNodePtr input_node = InsertCast(graph, allreduce_input, kNumberTypeFloat32);

  // create AllReduce
  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), input_node};
  auto allreduce = graph->NewCNode(allreduce_inputs);
  MS_EXCEPTION_IF_NULL(allreduce);
  allreduce->set_abstract(input_node->abstract());
  allreduce->set_scope(allreduce_input->scope());
  AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce);
  AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce);
  // Use SyncBatchNorm's op id, parsed from the "-opN" suffix of its full name, as AllReduce's
  // fusion attr, so the AllReduce ops created for the same SyncBatchNorm share a fusion id.
  auto sync_bn_opname = sync_bn_cnode->fullname_with_scope();
  auto opid_pos = sync_bn_opname.rfind("-op");
  if (opid_pos == std::string::npos || opid_pos + kPositionOffset >= sync_bn_opname.size()) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has no opid.";
    return nullptr;
  }
  int64_t opid = std::stol(sync_bn_opname.substr(opid_pos + kPositionOffset));
  // User-defined fusion values should be greater than 1, so remap smaller generated ids to the
  // top of the int64 range: opid 0 becomes INT64_MAX - 2 and opid 1 becomes INT64_MAX - 1.
  if (opid < kFusionNumThreshold) {
    opid = opid - kFusionNumThreshold + std::numeric_limits<int64_t>::max();
  }
  AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(opid), allreduce);

  // create Mul
  auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode);
  std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), allreduce,
                                        device_num_reciprocal_vnode};
  auto mul = graph->NewCNode(mul_inputs);
  MS_EXCEPTION_IF_NULL(mul);
  mul->set_abstract(input_node->abstract());
  mul->set_scope(allreduce_input->scope());

  // Cast the output back to the original data type to reduce the number of cast nodes.
  // Should be removed once BNTrainingReduce/BNTrainingUpdateGrad support fp16 output.
  return InsertCast(graph, mul, AnfAlgo::GetOutputInferDataType(allreduce_input, 0));
}

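// Match a BatchNorm node with any number of inputs.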
const BaseRef BnSplit::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
  return VectorRef({prim::kPrimBatchNorm, Xs});
}

const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  if (!GetBoolAttr(node, kAttrIsTraining)) {
    MS_LOG(INFO) << "Attr is_training should be true when splitting BatchNorm";
    return nullptr;
  }
  return SplitBatchNormForTBE(func_graph, node);
}

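// Match a SyncBatchNorm node with any number of inputs.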
const BaseRef SyncBnSplit::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimSyncBatchNorm, Xs});
}

const AnfNodePtr SyncBnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  return SyncBNSplitForTBE(func_graph, node);
}
}  // namespace opt
}  // namespace mindspore