/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"

namespace mindspore {
namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 1, 2};
constexpr size_t kBNGradOutputNum = 3;
constexpr size_t kBNAddReluGradOutputNum = 4;

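// Collects the TupleGetItem users of the BatchNormGrad node `bn` into `bn_outputs`.
// Returns true only when the node has exactly kBNGradOutputNum such users and every
// one of them extracts output 0, 1 or 2.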
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(bn);
  MS_EXCEPTION_IF_NULL(bn_outputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (manager->node_users().find(bn) == manager->node_users().end()) {
    return false;
  }
  size_t output_num = 0;
  for (const auto &node_index : manager->node_users()[bn]) {
    const AnfNodePtr &output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
      continue;
    }
    auto tuple_getitem_cnode = output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
    auto index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(index_node);
    auto value_node = index_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    int index = static_cast<int>(GetValue<int64_t>(value_node->value()));
    if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) {
      return false;
    }
    bn_outputs->push_back(output);
    output_num++;
  }
  return output_num == kBNGradOutputNum;
}

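// Copies the inferred output dtypes and shapes of `bn_grad`, appends those of output 0
// of `relu_grad`, and applies the combined lists to the fused node `bn_add_relu_grad`.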
void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad) {
  // Set output shape and dtype
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
  auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad);
  for (size_t i = 0; i < output_num; ++i) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(bn_grad, i));
    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(bn_grad, i));
  }

  outputs_type.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, 0));
  outputs_shape.push_back(AnfAlgo::GetOutputInferShape(relu_grad, 0));
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, bn_add_relu_grad.get());
}

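// Replaces every use of the original BatchNormGrad outputs (0..2) and of the ReluGrad
// output with the corresponding outputs of the fused node.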
void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad,
                   const CNodePtr &bn_add_relu_grad) {
  // Create outputs
  std::vector<AnfNodePtr> bn_add_relu_grad_output;
  CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output);
  if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node " << kBatchNormGradWithAddAndActivation << " must be "
                      << kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size();
  }

  // Get bn outputs
  std::vector<AnfNodePtr> bn_outputs;
  if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) {
    MS_LOG(INFO) << "The " << prim::kPrimBatchNormGrad
                 << " node should only have outputs 0, 1 and 2; the node will not be changed.";
    return;
  }

  // Replace original output
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem);
  size_t output_index = 0;
  for (const auto &output : bn_outputs) {
    (void)manager->Replace(output, bn_add_relu_grad_output[output_index]);
    output_index++;
  }

  (void)manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
}

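// Checks the structural preconditions for the fusion: NHWC layout, a channel count that
// is a multiple of kBNChannelMultipleFactor, a ReluGrad input with exactly two users,
// and a forward pattern of the form Relu(TensorAdd(BN#0, BN#1)).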
bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("format");
  MS_EXCEPTION_IF_NULL(format_attr);
  auto format = GetValue<std::string>(format_attr);
  if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
    return false;
  }
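  // The channel dimension (the last axis in NHWC) must be a multiple of
  // kBNChannelMultipleFactor, otherwise the fusion is skipped.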
  auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
  if ((shape.back() % kBNChannelMultipleFactor) != 0) {
    return false;
  }

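  // The ReluGrad input must have exactly two real users for the pattern to match.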
  auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(relu_grad);
  auto relu_users = GetRealNodeUsedList(graph, relu_grad);
  if (relu_users->size() != 2) {
    return false;
  }

  // Only process the pattern Relu(TensorAdd(BN#0, BN#1)): input 5 must be a TupleGetItem
  // taking its value from a forward BatchNormWithAddAndActivation node.
  auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
    return false;
  }
  auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
  if (AnfAlgo::GetCNodeName(forward_node) != kBatchNormWithAddAndActivation) {
    return false;
  }

  return true;
}
}  // namespace

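// Matches a BatchNormGrad node whose first input (dy) is produced by ReluGrad.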
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
  VectorRef batch_norm_grad =
    VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
  return batch_norm_grad;
}

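// Rewrites a matched ReluGrad -> BatchNormGrad pair into a single fused
// BatchNormGradWithAddAndActivation node and rewires its outputs. Always returns
// nullptr because the replacement is done manually through the graph manager.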
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                     const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);

  if (!PatternCheck(graph, node)) {
    return nullptr;
  }

  auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(relu_grad);
  auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
  MS_EXCEPTION_IF_NULL(dy);
  auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
  MS_EXCEPTION_IF_NULL(y);
  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
  MS_EXCEPTION_IF_NULL(x);
  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 2);
  MS_EXCEPTION_IF_NULL(scale);
  auto save_mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
  MS_EXCEPTION_IF_NULL(save_mean);
  auto save_var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 4);
  MS_EXCEPTION_IF_NULL(save_var);
  auto reserve = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
  MS_EXCEPTION_IF_NULL(reserve);
  auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(save_mean), 0);
  MS_EXCEPTION_IF_NULL(batch_norm);
  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
  MS_EXCEPTION_IF_NULL(bias);
  auto is_train = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("is_training");
  MS_EXCEPTION_IF_NULL(is_train);
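  // Only fuse when the forward BatchNorm ran in training mode.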
  if (!GetValue<bool>(is_train)) {
    return nullptr;
  }
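  // Build the fused node: inputs are dy, x, scale, the saved statistics, the reserve
  // space, plus the forward bias and the activation output y.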
  auto prim = std::make_shared<Primitive>(kBatchNormGradWithAddAndActivation);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
  auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(fused_batch_norm_add_relu_grad);
  AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
  SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
  ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
  device::gpu::SetKernelInfo(fused_batch_norm_add_relu_grad);
  return nullptr;
}
}  // namespace opt
}  // namespace mindspore