/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"

#include <algorithm>
#include <memory>
#include <vector>
#include <string>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"

namespace mindspore {
namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 1, 2};
constexpr size_t kBNGradOutputNum = 3;
constexpr size_t kBNAddReluGradOutputNum = 4;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(bn);
  MS_EXCEPTION_IF_NULL(bn_outputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (manager->node_users().find(bn) == manager->node_users().end()) {
    return false;
  }
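  // Walk the users of bn and keep every TupleGetItem on an allowed output index.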
  size_t output_num = 0;
  for (const auto &node_index : manager->node_users()[bn]) {
    const AnfNodePtr &output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
      continue;
    }
    auto tuple_getitem_cnode = output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
    auto index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(index_node);
    auto value_node = index_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    int index = static_cast<int>(GetValue<int64_t>(value_node->value()));
    if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) {
      return false;
    }
    bn_outputs->push_back(output);
    output_num++;
  }
  return output_num == kBNGradOutputNum;
}

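// Propagate inferred types/shapes to the fused node: the BatchNormGrad outputs
// followed by the ReluGrad output.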
void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad) {
  // set output shape and dtype
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
  auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad);
  for (size_t i = 0; i < output_num; ++i) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(bn_grad, i));
    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(bn_grad, i));
  }

  outputs_type.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, 0));
  outputs_shape.push_back(AnfAlgo::GetOutputInferShape(relu_grad, 0));
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, bn_add_relu_grad.get());
}

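// Redirect all users of the original BatchNormGrad and ReluGrad outputs to the
// corresponding outputs of the fused node.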
void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad,
                   const CNodePtr &bn_add_relu_grad) {
  // Create outputs
  std::vector<AnfNodePtr> bn_add_relu_grad_output;
  CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output);
  if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node " << kBatchNormGradWithAddAndActivation << " must be "
                      << kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size();
  }

  // Get bn outputs
  std::vector<AnfNodePtr> bn_outputs;
  if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) {
    MS_LOG(INFO) << "The " << prim::kPrimBatchNormGrad
                 << " node should only have outputs 0, 1 and 2. The node will not be changed.";
    return;
  }

  // Replace original output
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
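  // Replacements must follow output order, so sort the TupleGetItem nodes by index first.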
  sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem);
  size_t output_index = 0;
  for (const auto &output : bn_outputs) {
    (void)manager->Replace(output, bn_add_relu_grad_output[output_index]);
    output_index++;
  }

  (void)manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
}

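// Verify that the surrounding graph matches the fusible pattern before rewriting it.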
bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("format");
  MS_EXCEPTION_IF_NULL(format_attr);
  auto format = GetValue<std::string>(format_attr);
  if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
    return false;
  }
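  // Only fuse when the channel dimension (the last axis in NHWC) is a multiple
  // of kBNChannelMultipleFactor.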
  auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
  if ((shape.back() % kBNChannelMultipleFactor) != 0) {
    return false;
  }

  auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(relu_grad);
  auto relu_users = GetRealNodeUsedList(graph, relu_grad);
  MS_EXCEPTION_IF_NULL(relu_users);
  if (relu_users->size() != 2) {
    return false;
  }

  // process pattern as Relu(TensorAdd(BN#0, BN#1))
  auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
    return false;
  }
  auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
  MS_EXCEPTION_IF_NULL(forward_node);
  if (AnfAlgo::GetCNodeName(forward_node) != kBatchNormWithAddAndActivation) {
    return false;
  }

  return true;
}
}  // namespace

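// Match a BatchNormGrad whose dout comes from ReluGrad, i.e. the backward of the
// BN + Add + ReLU pattern.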
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
  VectorRef batch_norm_grad =
    VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
  return batch_norm_grad;
}

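// Replace the matched ReluGrad + BatchNormGrad pair with a single
// BatchNormGradWithAddAndActivation node and rewire its users.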
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                     const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);

  if (!PatternCheck(graph, node)) {
    return nullptr;
  }

  auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(relu_grad);
  auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
  MS_EXCEPTION_IF_NULL(dy);
  auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
  MS_EXCEPTION_IF_NULL(y);
  auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
  MS_EXCEPTION_IF_NULL(x);
  auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 2);
  MS_EXCEPTION_IF_NULL(scale);
  auto save_mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
  MS_EXCEPTION_IF_NULL(save_mean);
  auto save_var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 4);
  MS_EXCEPTION_IF_NULL(save_var);
  auto reserve = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
  MS_EXCEPTION_IF_NULL(reserve);
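  // save_mean is a TupleGetItem on the forward fused BatchNorm node; walk through
  // it to reach that node and fetch its bias input.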
  auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(save_mean), 0);
  MS_EXCEPTION_IF_NULL(batch_norm);
  auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
  MS_EXCEPTION_IF_NULL(bias);
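  // Only fuse when the forward BatchNorm ran in training mode.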
  auto is_train = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("is_training");
  MS_EXCEPTION_IF_NULL(is_train);
  if (!GetValue<bool>(is_train)) {
    return nullptr;
  }
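  // Build the fused node. It consumes dy directly (skipping ReluGrad) and
  // additionally takes the forward bias and the activation output y.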
  auto prim = std::make_shared<Primitive>(kBatchNormGradWithAddAndActivation);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
  auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(fused_batch_norm_add_relu_grad);
  AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
  SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
  ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
  device::gpu::SetKernelInfo(fused_batch_norm_add_relu_grad);
  return nullptr;
}
}  // namespace opt
}  // namespace mindspore