/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <string>
#include <algorithm>
#include <memory>
#include <utility>
#include "include/common/utils/utils.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "ir/manager.h"
#include "kernel/kernel_build_info.h"
#include "kernel/framework_utils.h"
#include "include/backend/kernel_info.h"
#include "backend/common/graph_kernel/decrease_transfer_precision.h"

namespace mindspore::graphkernel {
namespace {
constexpr auto kPatternOpaque = "Opaque";
}

static const size_t GK_MIN_SIZE = 2;  // minimum number of cnodes in a graph kernel sub-graph to be processed

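// Get the constant output index carried by a TupleGetItem node.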
int64_t ObtainGetItemIndex(const AnfNodePtr &getitem) {
  auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
  auto value_ptr = GetValueNode(index_node);
  return GetValue<int64_t>(value_ptr);
}

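// Check whether the `index`-th output of the graph kernel `node` is produced by a ReduceSum.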
bool IsPreNodeReduce(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out, size_t index) {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  if (is_tuple_out) {
    auto tuple_output = gk_graph->output()->cast<CNodePtr>();
    if (common::AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
      MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(tuple_output);
    }
    auto input_node = tuple_output->input(index + 1);
    if (common::AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
      return true;
    }
  }
  return false;
}

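// Number of CNodes inside the sub-graph of the graph kernel `node`.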
size_t GetGraphKernelSize(const AnfNodePtr &node) {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  return gk_graph->GetOrderedCnodes().size();
}

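// A candidate is a graph kernel node whose sub-graph has more than GK_MIN_SIZE cnodes
// and whose sub-graph name does not contain "atomic".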
bool IsCandidateNode(const AnfNodePtr &node) {
  bool is_gk = common::AnfAlgo::IsGraphKernel(node);
  if (is_gk) {
    auto num = GetGraphKernelSize(node);
    if (num > GK_MIN_SIZE) {
      auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
      auto graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
      if (graph_name.find("atomic") == std::string::npos) {
        return true;
      }
    }
  }
  return false;
}

bool IsAllUserCandidateNode(const AnfNodeIndexSet &users) {
  // Check whether every user of the node is a graph kernel candidate (relevant when the node has multiple users).
  bool result = std::all_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
    return IsCandidateNode(node_index.first);
  });
  return result;
}

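// For every fp32 edge between two candidate graph kernels, lower the transfer precision to fp16:
// the producer casts its output to fp16 before returning it (ProcessFather) and each consumer
// casts the value back to fp32 inside its own sub-graph (ProcessSon).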
bool DecreaseTransferPrecision::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  auto users_map = mng->node_users();
  auto todos = TopoSort(func_graph->get_return());
  bool changed = false;
  for (const auto &node : todos) {
    auto is_candidate = IsCandidateNode(node);
    if (is_candidate) {
      auto cnode = node->cast<CNodePtr>();
      for (size_t index = 1; index < cnode->size(); index++) {
        auto dtype = AnfAlgo::GetInputDeviceDataType(node, index - 1);
        if (dtype != kNumberTypeFloat32) {
          continue;
        }
        auto item = cnode->input(index);
        if (!item->cast<CNodePtr>()) {
          continue;
        }
        auto in_node = item->cast<CNodePtr>();
        // Case 1: the input is one output of a multi-output graph kernel, taken through TupleGetItem.
        if (IsPrimitive(in_node->input(0), prim::kPrimTupleGetItem)) {
          auto tuple_node = in_node->input(1);
          auto tuple_index = ObtainGetItemIndex(in_node);
          auto has_reduce_output = IsPreNodeReduce(func_graph, tuple_node, true, LongToSize(tuple_index));
          auto fail_flag = !IsCandidateNode(tuple_node) ||
                           (users_map[in_node].size() > 1 && !IsAllUserCandidateNode(users_map[in_node])) ||
                           has_reduce_output;
          if (fail_flag) {
            continue;
          }
          // mutate father
          (void)ProcessFather(func_graph, tuple_node, true, LongToSize(tuple_index));
          in_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(in_node)));
          // mutate sons
          for (auto each_out : users_map[in_node]) {
            (void)ProcessSon(func_graph, each_out.first, IntToSize(each_out.second));
          }
          changed = true;
        }
        // Case 2: the input is the single output of another graph kernel.
        if (IsCandidateNode(in_node)) {
          auto fail_flag = !IsAllUserCandidateNode(users_map[in_node]);
          if (fail_flag) {
            continue;
          }
          // mutate father
          (void)ProcessFather(func_graph, in_node, false, 0);
          // mutate sons
          (void)ProcessSon(func_graph, cnode, index);
          changed = true;
        }
      }
    }
  }
  return changed;
}

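// Change one output of the producer graph kernel to fp16: if the sub-graph already ends with a
// Cast(fp16 -> fp32), drop that Cast; otherwise append a Cast(fp32 -> fp16) to the selected output,
// then refresh the node's abstract and the output device type in its kernel build info.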
bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out,
                                              size_t index) const {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  auto mng = gk_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);

  // lambda that appends a Cast(fp32 -> fp16) after old_output inside the sub-graph
  auto func_add_cast_fp16 = [&gk_graph](const AnfNodePtr &old_output) {
    AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_output};
    auto cnode = gk_graph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(cnode);
    gk_graph->AddNode(cnode);
    cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_output)));
    cnode->set_scope(old_output->scope());
    SetNodeAttrSafely(kAttrDstType, kFloat16, cnode);
    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
    std::vector<TypeId> cnode_input_type = {kNumberTypeFloat32};
    std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
    std::vector<TypeId> cnode_output_type = {kNumberTypeFloat16};
    kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
    graph_info_builder.SetInputsFormat(cnode_input_format);
    graph_info_builder.SetInputsDeviceType(cnode_input_type);
    graph_info_builder.SetOutputsFormat(cnode_output_format);
    graph_info_builder.SetOutputsDeviceType(cnode_output_type);
    graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
    graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
    graph_info_builder.SetFusionType(kPatternOpaque);
    auto info_1 = graph_info_builder.Build();
    AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
    return cnode;
  };

  if (!is_tuple_out) {
    auto old_output = gk_graph->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(old_output);
    if (common::AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
        AnfAlgo::GetInputDeviceDataType(old_output, 0) == kNumberTypeFloat16 &&
        AnfAlgo::GetOutputDeviceDataType(old_output, 0) == kNumberTypeFloat32) {
      // the sub-graph already ends with Cast(fp16 -> fp32): drop the Cast and output fp16 directly
      auto real_output = old_output->input(1);
      gk_graph->set_output(real_output);
    } else {
      auto cnode = func_add_cast_fp16(old_output);
      gk_graph->set_output(cnode);
    }

    // update the abstract and the output device type of the graph kernel node
    node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(node)));
    auto gk_builder_info =
      std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
    std::vector<TypeId> gk_output_type = {kNumberTypeFloat16};
    gk_builder_info->SetOutputsDeviceType(gk_output_type);
    AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
    return true;
  } else {
    // cast for graph kernel with make tuple output
    auto tuple_output = gk_graph->output()->cast<CNodePtr>();
    if (common::AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
      MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(tuple_output);
    }
    auto input_node = tuple_output->input(index + 1);
    auto cnode = func_add_cast_fp16(input_node);
    tuple_output->set_input(index + 1, cnode);

    // Update MakeTuple node abstract
    AbstractBasePtrList abstract_list;
    for (size_t i = 1; i < tuple_output->size(); ++i) {
      (void)abstract_list.emplace_back(tuple_output->input(i)->abstract());
    }
    tuple_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));

    // Update Graph Kernel abstract
    node->set_abstract(tuple_output->abstract());

    // Update Graph Kernel Build Kernel Info
    auto old_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
    auto gk_builder_info = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_builder_info);
    auto origin_outputs_type = old_builder_info->GetAllOutputDeviceTypes();
    std::vector<TypeId> gk_output_type;
    for (size_t i = 0; i < origin_outputs_type.size(); ++i) {
      gk_output_type.push_back(origin_outputs_type[i]);
    }
    gk_output_type[index] = kNumberTypeFloat16;
    gk_builder_info->SetOutputsDeviceType(gk_output_type);
    AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());

    return true;
  }
}

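// Change the `index`-th input of the consumer graph kernel to fp16: update its input device type,
// then either drop a now-redundant Cast-to-fp16 user of the corresponding parameter or mark the
// parameter as fp16 and cast it back to fp32 inside the sub-graph.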
bool DecreaseTransferPrecision::ProcessSon(const FuncGraphPtr &, const AnfNodePtr &node, size_t index) const {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  auto mng = gk_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto old_input = gk_graph->get_inputs()[index - 1];
  MS_EXCEPTION_IF_NULL(old_input);

  auto user_nodes = mng->node_users()[old_input];
  // update the index-th input device type of the graph kernel node to fp16
  auto gk_builder_info =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
  auto new_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
  new_inputs_type[index - 1] = kNumberTypeFloat16;
  gk_builder_info->SetInputsDeviceType(new_inputs_type);
  AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
  AbstractBasePtr old_abstract = node->abstract()->Clone();
  node->set_abstract(old_abstract);

  // If the parameter is already consumed through a Cast to fp16, that Cast becomes redundant:
  // replace it with the parameter itself.
  for (const auto &user : user_nodes) {
    auto user_node = user.first;
    if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
        AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat16) {
      (void)mng->Replace(user_node, old_input);
      return true;
    }
  }

  // Otherwise mark the parameter as fp16 and insert a Cast(fp16 -> fp32) inside the sub-graph
  // to restore the original precision for the parameter's users.
  auto tensor_input = node->cast<CNodePtr>()->input(index);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_input};
  auto cnode = gk_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  gk_graph->AddNode(cnode);
  cnode->set_abstract(old_input->abstract());
  cnode->set_scope(old_input->scope());
  SetNodeAttrSafely(kAttrDstType, kFloat32, cnode);
  old_input->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_input)));
  cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
  std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
  std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
  std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
  kernel::KernelBuildInfo::KernelBuildInfoBuilder node_info_builder;
  node_info_builder.SetInputsFormat(cnode_input_format);
  node_info_builder.SetInputsDeviceType(cnode_input_type);
  node_info_builder.SetOutputsFormat(cnode_output_format);
  node_info_builder.SetOutputsDeviceType(cnode_output_type);
  node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  node_info_builder.SetFusionType(kPatternOpaque);
  auto info_1 = node_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
  (void)mng->Replace(old_input, cnode);
  return true;
}
}  // namespace mindspore::graphkernel
296