1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include <string>
19 #include <algorithm>
20 #include <memory>
21 #include <utility>
22 #include "base/core_ops.h"
23 #include "utils/utils.h"
24 #include "backend/optimizer/common/helper.h"
25 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 #include "ir/tensor.h"
28 #include "ir/manager.h"
29 #include "backend/kernel_compiler/kernel_build_info.h"
30 #include "backend/kernel_compiler/common_utils.h"
31 #include "runtime/device/kernel_info.h"
32 #include "backend/optimizer/graph_kernel/decrease_transfer_precision.h"
33 namespace mindspore {
34 namespace opt {
// A graph kernel is only a candidate for fp32 -> fp16 transfer narrowing (IsCandidateNode)
// when its sub-graph holds strictly more ordered CNodes than this threshold.
static constexpr size_t GK_MIN_SIZE = 2;
36
ObtainGetItemIndex(const AnfNodePtr & getitem)37 int64_t ObtainGetItemIndex(const AnfNodePtr &getitem) {
38 auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
39 auto value_ptr = GetValueNode(index_node);
40 return GetValue<int64_t>(value_ptr);
41 }
42
IsPreNodeReduce(const FuncGraphPtr &,const AnfNodePtr & node,bool is_tuple_out,size_t index)43 bool IsPreNodeReduce(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out, size_t index) {
44 auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
45 MS_EXCEPTION_IF_NULL(gk_graph);
46 if (is_tuple_out) {
47 auto tuple_output = gk_graph->output()->cast<CNodePtr>();
48 if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
49 MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
50 }
51 auto input_node = tuple_output->input(index + 1);
52 if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
53 return true;
54 }
55 }
56 return false;
57 }
58
GetGraphKernelSize(const AnfNodePtr & node)59 size_t GetGraphKernelSize(const AnfNodePtr &node) {
60 auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
61 MS_EXCEPTION_IF_NULL(gk_graph);
62 return gk_graph->GetOrderedCnodes().size();
63 }
64
IsCandidateNode(const AnfNodePtr & node)65 bool IsCandidateNode(const AnfNodePtr &node) {
66 bool is_gk = AnfAlgo::IsGraphKernel(node);
67 if (is_gk) {
68 auto num = GetGraphKernelSize(node);
69 if (num > GK_MIN_SIZE) {
70 auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
71 auto graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
72 if (graph_name.find("atomic") == std::string::npos) {
73 return true;
74 }
75 }
76 }
77 return false;
78 }
79
IsAllUserCandidateNode(const AnfNodeIndexSet & users)80 bool IsAllUserCandidateNode(const AnfNodeIndexSet &users) {
81 // check whether all user are graph kernel when more than one users for the in_node
82 bool result = std::all_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
83 return IsCandidateNode(node_index.first);
84 });
85 return result;
86 }
87
Run(const FuncGraphPtr & func_graph)88 bool DecreaseTransferPrecision::Run(const FuncGraphPtr &func_graph) {
89 auto mng = func_graph->manager();
90 if (mng == nullptr) {
91 mng = Manage(func_graph, true);
92 func_graph->set_manager(mng);
93 }
94 auto users_map = mng->node_users();
95 auto todos = TopoSort(func_graph->get_return());
96 bool changed = false;
97 for (const auto &node : todos) {
98 auto is_candidate = IsCandidateNode(node);
99 if (is_candidate) {
100 auto cnode = node->cast<CNodePtr>();
101 for (size_t index = 1; index < cnode->size(); index++) {
102 auto dtype = AnfAlgo::GetInputDeviceDataType(node, index - 1);
103 if (dtype != kNumberTypeFloat32) {
104 continue;
105 }
106 auto item = cnode->input(index);
107 if (!item->cast<CNodePtr>()) {
108 continue;
109 }
110 auto in_node = item->cast<CNodePtr>();
111 if (IsPrimitive(in_node->input(0), prim::kPrimTupleGetItem)) {
112 auto tuple_node = in_node->input(1);
113 auto tuple_index = ObtainGetItemIndex(in_node);
114 auto has_reduce_output = IsPreNodeReduce(func_graph, tuple_node, true, LongToSize(tuple_index));
115 auto fail_flag = !IsCandidateNode(tuple_node) ||
116 (users_map[in_node].size() > 1 && IsAllUserCandidateNode(users_map[in_node])) ||
117 has_reduce_output;
118 if (fail_flag) {
119 continue;
120 }
121 // mutate father
122 (void)Process_Father(func_graph, tuple_node, true, LongToSize(tuple_index));
123 in_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(in_node)));
124 // mutate sons
125 for (auto each_out : users_map[in_node]) {
126 (void)Process_Son(func_graph, each_out.first, IntToSize(each_out.second));
127 }
128 }
129 if (IsCandidateNode(in_node)) {
130 auto fail_flag = !IsAllUserCandidateNode(users_map[in_node]);
131 if (fail_flag) {
132 continue;
133 }
134 // mutate father
135 (void)Process_Father(func_graph, in_node, false, 0);
136 // mutate sons
137 (void)Process_Son(func_graph, cnode, index);
138 }
139 }
140 }
141 }
142 return changed;
143 }
144
Process_Father(const FuncGraphPtr &,const AnfNodePtr & node,bool is_tuple_out,size_t index)145 bool DecreaseTransferPrecision::Process_Father(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out,
146 size_t index) {
147 auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
148 MS_EXCEPTION_IF_NULL(gk_graph);
149 auto mng = gk_graph->manager();
150 MS_EXCEPTION_IF_NULL(mng);
151
152 // lambda func for cast fp32 to fp16
153 auto func_add_cast_fp16 = [&gk_graph](const AnfNodePtr &old_output) {
154 AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_output};
155 auto cnode = gk_graph->NewCNode(inputs);
156 MS_EXCEPTION_IF_NULL(cnode);
157 gk_graph->AddNode(cnode);
158 cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_output)));
159 cnode->set_scope(old_output->scope());
160 SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat16->type_id())), cnode);
161 cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
162 std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
163 std::vector<TypeId> cnode_input_type = {kNumberTypeFloat32};
164 std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
165 std::vector<TypeId> cnode_output_type = {kNumberTypeFloat16};
166 kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
167 graph_info_builder.SetInputsFormat(cnode_input_format);
168 graph_info_builder.SetInputsDeviceType(cnode_input_type);
169 graph_info_builder.SetOutputsFormat(cnode_output_format);
170 graph_info_builder.SetOutputsDeviceType(cnode_output_type);
171 graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
172 graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
173 graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
174 auto info_1 = graph_info_builder.Build();
175 AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
176 return cnode;
177 };
178
179 if (!is_tuple_out) {
180 auto old_output = gk_graph->output()->cast<CNodePtr>();
181 MS_EXCEPTION_IF_NULL(old_output);
182 if (AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
183 AnfAlgo::GetInputDeviceDataType(old_output, 0) == kNumberTypeFloat16 &&
184 AnfAlgo::GetOutputDeviceDataType(old_output, 0) == kNumberTypeFloat32) {
185 auto real_output = old_output->input(1);
186 gk_graph->set_output(real_output);
187 } else {
188 auto cnode = func_add_cast_fp16(old_output);
189 gk_graph->set_output(cnode);
190 }
191
192 // get kernel build info
193 node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(node)));
194 auto gk_builder_info =
195 std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
196 std::vector<TypeId> gk_output_type = {kNumberTypeFloat16};
197 gk_builder_info->SetOutputsDeviceType(gk_output_type);
198 AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
199 return true;
200 } else {
201 // cast for graph kernel with make tuple output
202 auto tuple_output = gk_graph->output()->cast<CNodePtr>();
203 if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
204 MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
205 }
206 auto input_node = tuple_output->input(index + 1);
207 auto cnode = func_add_cast_fp16(input_node);
208 tuple_output->set_input(index + 1, cnode);
209
210 // Update MakeTuple node abstract
211 AbstractBasePtrList abstract_list;
212 for (size_t i = 1; i < tuple_output->size(); ++i) {
213 (void)abstract_list.emplace_back(tuple_output->input(i)->abstract());
214 }
215 tuple_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
216
217 // Update Graph Kernel abstract
218 node->set_abstract(tuple_output->abstract());
219
220 // Update Graph Kernel Build Kernel Info
221 auto old_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
222 auto gk_builder_info = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_builder_info);
223 auto origin_outputs_type = old_builder_info->GetAllOutputDeviceTypes();
224 std::vector<TypeId> gk_output_type;
225 for (size_t i = 0; i < origin_outputs_type.size(); ++i) {
226 gk_output_type.push_back(origin_outputs_type[i]);
227 }
228 gk_output_type[index] = kNumberTypeFloat16;
229 gk_builder_info->SetOutputsDeviceType(gk_output_type);
230 AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
231
232 return true;
233 }
234 }
235
Process_Son(const FuncGraphPtr &,const AnfNodePtr & node,size_t index)236 bool DecreaseTransferPrecision::Process_Son(const FuncGraphPtr &, const AnfNodePtr &node, size_t index) {
237 auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
238 MS_EXCEPTION_IF_NULL(gk_graph);
239 auto mng = gk_graph->manager();
240 MS_EXCEPTION_IF_NULL(mng);
241 auto old_input = gk_graph->get_inputs()[index - 1];
242 MS_EXCEPTION_IF_NULL(old_input);
243
244 auto user_nodes = mng->node_users()[old_input];
245 // get kernel build info
246 auto gk_builder_info =
247 std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
248 auto ori_input_format = AnfAlgo::GetAllInputDeviceTypes(node);
249 std::vector<TypeId> &new_inputs_type = ori_input_format;
250 new_inputs_type[index - 1] = kNumberTypeFloat16;
251 gk_builder_info->SetInputsDeviceType(new_inputs_type);
252 AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
253 AbstractBasePtr old_abstract = node->abstract()->Clone();
254 node->set_abstract(old_abstract);
255
256 for (const auto &user : user_nodes) {
257 auto user_node = user.first;
258 if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
259 AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat16) {
260 (void)mng->Replace(user_node, old_input);
261 return true;
262 }
263 }
264
265 auto tensor_input = node->cast<CNodePtr>()->input(index);
266 AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_input};
267 auto cnode = gk_graph->NewCNode(inputs);
268 gk_graph->AddNode(cnode);
269 cnode->set_abstract(old_input->abstract());
270 cnode->set_scope(old_input->scope());
271 SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat32->type_id())), cnode);
272 MS_EXCEPTION_IF_NULL(cnode);
273 old_input->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_input)));
274 cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
275 std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
276 std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
277 std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
278 std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
279 kernel::KernelBuildInfo::KernelBuildInfoBuilder node_info_builder;
280 node_info_builder.SetInputsFormat(cnode_input_format);
281 node_info_builder.SetInputsDeviceType(cnode_input_type);
282 node_info_builder.SetOutputsFormat(cnode_output_format);
283 node_info_builder.SetOutputsDeviceType(cnode_output_type);
284 node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
285 node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
286 node_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
287 auto info_1 = node_info_builder.Build();
288 AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
289 (void)mng->Replace(old_input, cnode);
290 return true;
291 }
292 } // namespace opt
293 } // namespace mindspore
294