/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include <vector>
18 #include <string>
19 #include <algorithm>
20 #include <memory>
21 #include <utility>
22 #include "include/common/utils/utils.h"
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/math_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "backend/common/graph_kernel/graph_kernel_helper.h"
27 #include "ir/manager.h"
28 #include "kernel/kernel_build_info.h"
29 #include "kernel/framework_utils.h"
30 #include "include/backend/kernel_info.h"
31 #include "backend/common/graph_kernel/decrease_transfer_precision.h"
32 
namespace mindspore::graphkernel {
namespace {
// Fusion-type tag assigned to the Cast kernels this pass inserts.
constexpr auto kPatternOpaque = "Opaque";
}

// Minimum number of CNodes a graph kernel must contain before its transfer
// precision is considered worth decreasing.
static const size_t GK_MIN_SIZE = 2;
39 
ObtainGetItemIndex(const AnfNodePtr & getitem)40 int64_t ObtainGetItemIndex(const AnfNodePtr &getitem) {
41   auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
42   auto value_ptr = GetValueNode(index_node);
43   return GetValue<int64_t>(value_ptr);
44 }
45 
IsPreNodeReduce(const FuncGraphPtr &,const AnfNodePtr & node,bool is_tuple_out,size_t index)46 bool IsPreNodeReduce(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out, size_t index) {
47   auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
48   MS_EXCEPTION_IF_NULL(gk_graph);
49   if (is_tuple_out) {
50     auto tuple_output = gk_graph->output()->cast<CNodePtr>();
51     if (common::AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
52       MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(tuple_output);
53     }
54     auto input_node = tuple_output->input(index + 1);
55     if (common::AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
56       return true;
57     }
58   }
59   return false;
60 }
61 
GetGraphKernelSize(const AnfNodePtr & node)62 size_t GetGraphKernelSize(const AnfNodePtr &node) {
63   auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
64   MS_EXCEPTION_IF_NULL(gk_graph);
65   return gk_graph->GetOrderedCnodes().size();
66 }
67 
IsCandidateNode(const AnfNodePtr & node)68 bool IsCandidateNode(const AnfNodePtr &node) {
69   bool is_gk = common::AnfAlgo::IsGraphKernel(node);
70   if (is_gk) {
71     auto num = GetGraphKernelSize(node);
72     if (num > GK_MIN_SIZE) {
73       auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
74       auto graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
75       if (graph_name.find("atomic") == std::string::npos) {
76         return true;
77       }
78     }
79   }
80   return false;
81 }
82 
IsAllUserCandidateNode(const AnfNodeIndexSet & users)83 bool IsAllUserCandidateNode(const AnfNodeIndexSet &users) {
84   // check whether all user are graph kernel when more than one users for the in_node
85   bool result = std::all_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
86     return IsCandidateNode(node_index.first);
87   });
88   return result;
89 }
90 
Run(const FuncGraphPtr & func_graph)91 bool DecreaseTransferPrecision::Run(const FuncGraphPtr &func_graph) {
92   auto mng = func_graph->manager();
93   if (mng == nullptr) {
94     mng = Manage(func_graph, true);
95     func_graph->set_manager(mng);
96   }
97   auto users_map = mng->node_users();
98   auto todos = TopoSort(func_graph->get_return());
99   bool changed = false;
100   for (const auto &node : todos) {
101     auto is_candidate = IsCandidateNode(node);
102     if (is_candidate) {
103       auto cnode = node->cast<CNodePtr>();
104       for (size_t index = 1; index < cnode->size(); index++) {
105         auto dtype = AnfAlgo::GetInputDeviceDataType(node, index - 1);
106         if (dtype != kNumberTypeFloat32) {
107           continue;
108         }
109         auto item = cnode->input(index);
110         if (!item->cast<CNodePtr>()) {
111           continue;
112         }
113         auto in_node = item->cast<CNodePtr>();
114         if (IsPrimitive(in_node->input(0), prim::kPrimTupleGetItem)) {
115           auto tuple_node = in_node->input(1);
116           auto tuple_index = ObtainGetItemIndex(in_node);
117           auto has_reduce_output = IsPreNodeReduce(func_graph, tuple_node, true, LongToSize(tuple_index));
118           auto fail_flag = !IsCandidateNode(tuple_node) ||
119                            (users_map[in_node].size() > 1 && IsAllUserCandidateNode(users_map[in_node])) ||
120                            has_reduce_output;
121           if (fail_flag) {
122             continue;
123           }
124           // mutate father
125           (void)ProcessFather(func_graph, tuple_node, true, LongToSize(tuple_index));
126           in_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(in_node)));
127           // mutate sons
128           for (auto each_out : users_map[in_node]) {
129             (void)ProcessSon(func_graph, each_out.first, IntToSize(each_out.second));
130           }
131         }
132         if (IsCandidateNode(in_node)) {
133           auto fail_flag = !IsAllUserCandidateNode(users_map[in_node]);
134           if (fail_flag) {
135             continue;
136           }
137           // mutate father
138           (void)ProcessFather(func_graph, in_node, false, 0);
139           // mutate sons
140           (void)ProcessSon(func_graph, cnode, index);
141         }
142       }
143     }
144   }
145   return changed;
146 }
147 
// Rewrites the producer ("father") graph kernel so the selected output is
// emitted as fp16 instead of fp32, updating both the sub-graph output and the
// fused node's kernel build info / abstract.
// @param node         graph-kernel CNode whose output precision is lowered
// @param is_tuple_out true when the kernel's output is a MakeTuple
// @param index        tuple element to lower (ignored when !is_tuple_out)
// @return always true on success; throws via MS_LOG(EXCEPTION) on malformed graphs
bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out,
                                              size_t index) const {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  auto mng = gk_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);

  // lambda func for cast fp32 to fp16: appends Cast(old_output) inside the
  // sub-graph, gives it an fp16 abstract and a freshly built AKG kernel info.
  auto func_add_cast_fp16 = [&gk_graph](const AnfNodePtr &old_output) {
    AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_output};
    auto cnode = gk_graph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(cnode);
    gk_graph->AddNode(cnode);
    cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_output)));
    cnode->set_scope(old_output->scope());
    SetNodeAttrSafely(kAttrDstType, kFloat16, cnode);
    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    // The cast keeps the producer's data format on both sides; only the
    // device type changes (fp32 in, fp16 out).
    std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
    std::vector<TypeId> cnode_input_type = {kNumberTypeFloat32};
    std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
    std::vector<TypeId> cnode_output_type = {kNumberTypeFloat16};
    kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
    graph_info_builder.SetInputsFormat(cnode_input_format);
    graph_info_builder.SetInputsDeviceType(cnode_input_type);
    graph_info_builder.SetOutputsFormat(cnode_output_format);
    graph_info_builder.SetOutputsDeviceType(cnode_output_type);
    graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
    graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
    graph_info_builder.SetFusionType(kPatternOpaque);
    auto info_1 = graph_info_builder.Build();
    AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
    return cnode;
  };

  if (!is_tuple_out) {
    auto old_output = gk_graph->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(old_output);
    // If the sub-graph already ends with a Cast fp16->fp32, strip that cast
    // instead of stacking a new fp32->fp16 cast behind it.
    if (common::AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
        AnfAlgo::GetInputDeviceDataType(old_output, 0) == kNumberTypeFloat16 &&
        AnfAlgo::GetOutputDeviceDataType(old_output, 0) == kNumberTypeFloat32) {
      auto real_output = old_output->input(1);
      gk_graph->set_output(real_output);
    } else {
      auto cnode = func_add_cast_fp16(old_output);
      gk_graph->set_output(cnode);
    }

    // get kernel build info: the fused node now advertises an fp16 output.
    node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(node)));
    auto gk_builder_info =
      std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
    std::vector<TypeId> gk_output_type = {kNumberTypeFloat16};
    gk_builder_info->SetOutputsDeviceType(gk_output_type);
    AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
    return true;
  } else {
    // cast for graph kernel with make tuple output: only tuple element `index`
    // is lowered; the other outputs keep their original type.
    auto tuple_output = gk_graph->output()->cast<CNodePtr>();
    if (common::AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
      MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(tuple_output);
    }
    auto input_node = tuple_output->input(index + 1);
    auto cnode = func_add_cast_fp16(input_node);
    tuple_output->set_input(index + 1, cnode);

    // Update MakeTuple node abstract
    AbstractBasePtrList abstract_list;
    for (size_t i = 1; i < tuple_output->size(); ++i) {
      (void)abstract_list.emplace_back(tuple_output->input(i)->abstract());
    }
    tuple_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));

    // Update Graph Kernel abstract
    node->set_abstract(tuple_output->abstract());

    // Update Graph Kernel Build Kernel Info: copy the old output types and
    // patch only the lowered element to fp16.
    auto old_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
    auto gk_builder_info = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_builder_info);
    auto origin_outputs_type = old_builder_info->GetAllOutputDeviceTypes();
    std::vector<TypeId> gk_output_type;
    for (size_t i = 0; i < origin_outputs_type.size(); ++i) {
      gk_output_type.push_back(origin_outputs_type[i]);
    }
    gk_output_type[index] = kNumberTypeFloat16;
    gk_builder_info->SetOutputsDeviceType(gk_output_type);
    AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());

    return true;
  }
}
238 
// Rewrites one consumer ("son") graph kernel so that input `index` is received
// as fp16: patches the kernel build info, then inside the sub-graph either
// removes an existing Cast-to-fp16 user of the corresponding parameter or
// inserts a Cast back to fp32 after it.
// @param node  the consumer graph-kernel CNode
// @param index 1-based input position on the CNode (input 0 is the primitive)
bool DecreaseTransferPrecision::ProcessSon(const FuncGraphPtr &, const AnfNodePtr &node, size_t index) const {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  auto mng = gk_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  // Sub-graph parameters are 0-based, CNode inputs are 1-based.
  auto old_input = gk_graph->get_inputs()[index - 1];
  MS_EXCEPTION_IF_NULL(old_input);

  auto user_nodes = mng->node_users()[old_input];
  // get kernel build info: declare input `index` as fp16 on the fused node.
  auto gk_builder_info =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
  auto ori_input_format = AnfAlgo::GetAllInputDeviceTypes(node);
  std::vector<TypeId> &new_inputs_type = ori_input_format;
  new_inputs_type[index - 1] = kNumberTypeFloat16;
  gk_builder_info->SetInputsDeviceType(new_inputs_type);
  AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
  // NOTE(review): cloning the abstract and setting it back looks like a no-op;
  // presumably meant to detach shared abstract state — confirm intent.
  AbstractBasePtr old_abstract = node->abstract()->Clone();
  node->set_abstract(old_abstract);

  // If the parameter already feeds a Cast that produces fp16, the cast becomes
  // redundant once the input arrives in fp16: replace it with the parameter.
  for (const auto &user : user_nodes) {
    auto user_node = user.first;
    if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
        AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat16) {
      (void)mng->Replace(user_node, old_input);
      return true;
    }
  }

  // Otherwise insert Cast(fp16 -> fp32) right after the parameter and reroute
  // all of its users through the cast; the parameter itself becomes fp16.
  auto tensor_input = node->cast<CNodePtr>()->input(index);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_input};
  auto cnode = gk_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  gk_graph->AddNode(cnode);
  // The cast takes over the parameter's original (fp32) abstract.
  cnode->set_abstract(old_input->abstract());
  cnode->set_scope(old_input->scope());
  SetNodeAttrSafely(kAttrDstType, kFloat32, cnode);
  old_input->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_input)));
  cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  // Format follows the actual tensor fed to the fused node from outside.
  std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
  std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
  std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
  std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
  kernel::KernelBuildInfo::KernelBuildInfoBuilder node_info_builder;
  node_info_builder.SetInputsFormat(cnode_input_format);
  node_info_builder.SetInputsDeviceType(cnode_input_type);
  node_info_builder.SetOutputsFormat(cnode_output_format);
  node_info_builder.SetOutputsDeviceType(cnode_output_type);
  node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  node_info_builder.SetFusionType(kPatternOpaque);
  auto info_1 = node_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
  (void)mng->Replace(old_input, cnode);
  return true;
}
295 }  // namespace mindspore::graphkernel
296