• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <vector>
18 #include <string>
19 #include <algorithm>
20 #include <memory>
21 #include <utility>
22 #include "base/core_ops.h"
23 #include "utils/utils.h"
24 #include "backend/optimizer/common/helper.h"
25 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 #include "ir/tensor.h"
28 #include "ir/manager.h"
29 #include "backend/kernel_compiler/kernel_build_info.h"
30 #include "backend/kernel_compiler/common_utils.h"
31 #include "runtime/device/kernel_info.h"
32 #include "backend/optimizer/graph_kernel/decrease_transfer_precision.h"
namespace mindspore {
namespace opt {
// A graph kernel sub-graph must contain more than this many cnodes before it
// is considered a candidate for transfer-precision reduction (see IsCandidateNode).
static const size_t GK_MIN_SIZE = 2;
36 
ObtainGetItemIndex(const AnfNodePtr & getitem)37 int64_t ObtainGetItemIndex(const AnfNodePtr &getitem) {
38   auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
39   auto value_ptr = GetValueNode(index_node);
40   return GetValue<int64_t>(value_ptr);
41 }
42 
IsPreNodeReduce(const FuncGraphPtr &,const AnfNodePtr & node,bool is_tuple_out,size_t index)43 bool IsPreNodeReduce(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out, size_t index) {
44   auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
45   MS_EXCEPTION_IF_NULL(gk_graph);
46   if (is_tuple_out) {
47     auto tuple_output = gk_graph->output()->cast<CNodePtr>();
48     if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
49       MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
50     }
51     auto input_node = tuple_output->input(index + 1);
52     if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
53       return true;
54     }
55   }
56   return false;
57 }
58 
GetGraphKernelSize(const AnfNodePtr & node)59 size_t GetGraphKernelSize(const AnfNodePtr &node) {
60   auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
61   MS_EXCEPTION_IF_NULL(gk_graph);
62   return gk_graph->GetOrderedCnodes().size();
63 }
64 
IsCandidateNode(const AnfNodePtr & node)65 bool IsCandidateNode(const AnfNodePtr &node) {
66   bool is_gk = AnfAlgo::IsGraphKernel(node);
67   if (is_gk) {
68     auto num = GetGraphKernelSize(node);
69     if (num > GK_MIN_SIZE) {
70       auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
71       auto graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
72       if (graph_name.find("atomic") == std::string::npos) {
73         return true;
74       }
75     }
76   }
77   return false;
78 }
79 
IsAllUserCandidateNode(const AnfNodeIndexSet & users)80 bool IsAllUserCandidateNode(const AnfNodeIndexSet &users) {
81   // check whether all user are graph kernel when more than one users for the in_node
82   bool result = std::all_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
83     return IsCandidateNode(node_index.first);
84   });
85   return result;
86 }
87 
Run(const FuncGraphPtr & func_graph)88 bool DecreaseTransferPrecision::Run(const FuncGraphPtr &func_graph) {
89   auto mng = func_graph->manager();
90   if (mng == nullptr) {
91     mng = Manage(func_graph, true);
92     func_graph->set_manager(mng);
93   }
94   auto users_map = mng->node_users();
95   auto todos = TopoSort(func_graph->get_return());
96   bool changed = false;
97   for (const auto &node : todos) {
98     auto is_candidate = IsCandidateNode(node);
99     if (is_candidate) {
100       auto cnode = node->cast<CNodePtr>();
101       for (size_t index = 1; index < cnode->size(); index++) {
102         auto dtype = AnfAlgo::GetInputDeviceDataType(node, index - 1);
103         if (dtype != kNumberTypeFloat32) {
104           continue;
105         }
106         auto item = cnode->input(index);
107         if (!item->cast<CNodePtr>()) {
108           continue;
109         }
110         auto in_node = item->cast<CNodePtr>();
111         if (IsPrimitive(in_node->input(0), prim::kPrimTupleGetItem)) {
112           auto tuple_node = in_node->input(1);
113           auto tuple_index = ObtainGetItemIndex(in_node);
114           auto has_reduce_output = IsPreNodeReduce(func_graph, tuple_node, true, LongToSize(tuple_index));
115           auto fail_flag = !IsCandidateNode(tuple_node) ||
116                            (users_map[in_node].size() > 1 && IsAllUserCandidateNode(users_map[in_node])) ||
117                            has_reduce_output;
118           if (fail_flag) {
119             continue;
120           }
121           // mutate father
122           (void)Process_Father(func_graph, tuple_node, true, LongToSize(tuple_index));
123           in_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(in_node)));
124           // mutate sons
125           for (auto each_out : users_map[in_node]) {
126             (void)Process_Son(func_graph, each_out.first, IntToSize(each_out.second));
127           }
128         }
129         if (IsCandidateNode(in_node)) {
130           auto fail_flag = !IsAllUserCandidateNode(users_map[in_node]);
131           if (fail_flag) {
132             continue;
133           }
134           // mutate father
135           (void)Process_Father(func_graph, in_node, false, 0);
136           // mutate sons
137           (void)Process_Son(func_graph, cnode, index);
138         }
139       }
140     }
141   }
142   return changed;
143 }
144 
// Rewrites the producer ("father") graph kernel so that its output (or its
// `index`-th tuple output when is_tuple_out is true) is float16 instead of
// float32: either drops an existing trailing fp16->fp32 Cast, or appends a new
// fp32->fp16 Cast inside the sub-graph, then refreshes the node's abstract and
// selected kernel build info to match.
bool DecreaseTransferPrecision::Process_Father(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out,
                                               size_t index) {
  auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  auto mng = gk_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);

  // lambda func for cast fp32 to fp16
  // Builds a Cast(old_output) cnode inside the sub-graph with fp32 input /
  // fp16 output build info, preserving old_output's format and scope.
  auto func_add_cast_fp16 = [&gk_graph](const AnfNodePtr &old_output) {
    AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_output};
    auto cnode = gk_graph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(cnode);
    gk_graph->AddNode(cnode);
    cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_output)));
    cnode->set_scope(old_output->scope());
    SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat16->type_id())), cnode);
    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    // The cast keeps the producer's output format on both sides; only the
    // device type changes (fp32 in, fp16 out).
    std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
    std::vector<TypeId> cnode_input_type = {kNumberTypeFloat32};
    std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
    std::vector<TypeId> cnode_output_type = {kNumberTypeFloat16};
    kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
    graph_info_builder.SetInputsFormat(cnode_input_format);
    graph_info_builder.SetInputsDeviceType(cnode_input_type);
    graph_info_builder.SetOutputsFormat(cnode_output_format);
    graph_info_builder.SetOutputsDeviceType(cnode_output_type);
    graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
    graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
    graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
    auto info_1 = graph_info_builder.Build();
    AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
    return cnode;
  };

  if (!is_tuple_out) {
    auto old_output = gk_graph->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(old_output);
    if (AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
        AnfAlgo::GetInputDeviceDataType(old_output, 0) == kNumberTypeFloat16 &&
        AnfAlgo::GetOutputDeviceDataType(old_output, 0) == kNumberTypeFloat32) {
      // The sub-graph already ends in an fp16->fp32 Cast: drop it and expose
      // the fp16 value directly.
      auto real_output = old_output->input(1);
      gk_graph->set_output(real_output);
    } else {
      // Otherwise append a new fp32->fp16 Cast as the sub-graph output.
      auto cnode = func_add_cast_fp16(old_output);
      gk_graph->set_output(cnode);
    }

    // get kernel build info
    // The outer graph kernel node now yields fp16: update abstract and the
    // (single) output device type in its selected kernel build info.
    node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(node)));
    auto gk_builder_info =
      std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
    std::vector<TypeId> gk_output_type = {kNumberTypeFloat16};
    gk_builder_info->SetOutputsDeviceType(gk_output_type);
    AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
    return true;
  } else {
    // cast for graph kernel with make tuple output
    auto tuple_output = gk_graph->output()->cast<CNodePtr>();
    if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
      MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
    }
    // MakeTuple input 0 is the primitive, so element `index` is input index+1.
    auto input_node = tuple_output->input(index + 1);
    auto cnode = func_add_cast_fp16(input_node);
    tuple_output->set_input(index + 1, cnode);

    // Update MakeTuple node abstract
    AbstractBasePtrList abstract_list;
    for (size_t i = 1; i < tuple_output->size(); ++i) {
      (void)abstract_list.emplace_back(tuple_output->input(i)->abstract());
    }
    tuple_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));

    // Update Graph Kernel abstract
    node->set_abstract(tuple_output->abstract());

    // Update Graph Kernel Build Kernel Info
    // Only the `index`-th output device type flips to fp16; the others keep
    // their original types.
    auto old_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
    auto gk_builder_info = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_builder_info);
    auto origin_outputs_type = old_builder_info->GetAllOutputDeviceTypes();
    std::vector<TypeId> gk_output_type;
    for (size_t i = 0; i < origin_outputs_type.size(); ++i) {
      gk_output_type.push_back(origin_outputs_type[i]);
    }
    gk_output_type[index] = kNumberTypeFloat16;
    gk_builder_info->SetOutputsDeviceType(gk_output_type);
    AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());

    return true;
  }
}
235 
// Rewrites a consumer ("son") graph kernel so that its `index`-th input (which
// the caller has just turned into float16) is accepted: updates the input
// device type in the node's build info, then inside the sub-graph either
// reuses an existing fp32->fp16 Cast of the matching parameter or inserts a
// new fp16->fp32 Cast after it.
bool DecreaseTransferPrecision::Process_Son(const FuncGraphPtr &, const AnfNodePtr &node, size_t index) {
  auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  auto mng = gk_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  // Outer input `index` (inputs start at 1 after the primitive) corresponds to
  // sub-graph parameter index-1.
  auto old_input = gk_graph->get_inputs()[index - 1];
  MS_EXCEPTION_IF_NULL(old_input);

  auto user_nodes = mng->node_users()[old_input];
  // get kernel build info
  auto gk_builder_info =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
  auto ori_input_format = AnfAlgo::GetAllInputDeviceTypes(node);
  std::vector<TypeId> &new_inputs_type = ori_input_format;
  new_inputs_type[index - 1] = kNumberTypeFloat16;
  gk_builder_info->SetInputsDeviceType(new_inputs_type);
  AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
  // NOTE(review): cloning the abstract and setting it back looks like a no-op
  // semantically; presumably it forces a fresh abstract object — confirm intent.
  AbstractBasePtr old_abstract = node->abstract()->Clone();
  node->set_abstract(old_abstract);

  for (const auto &user : user_nodes) {
    auto user_node = user.first;
    if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
        AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat16) {
      // The parameter is already cast to fp16 inside the sub-graph: drop that
      // cast and feed the (now fp16) parameter directly.
      // NOTE(review): this early-return path leaves old_input's abstract
      // untouched, unlike the insert-cast path below — verify that is intended.
      (void)mng->Replace(user_node, old_input);
      return true;
    }
  }

  // No reusable cast found: build a Cast(old_input) fp16->fp32 node, mark the
  // parameter itself as fp16, and route all former users of the parameter
  // through the new cast.
  auto tensor_input = node->cast<CNodePtr>()->input(index);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_input};
  auto cnode = gk_graph->NewCNode(inputs);
  gk_graph->AddNode(cnode);
  cnode->set_abstract(old_input->abstract());
  cnode->set_scope(old_input->scope());
  SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat32->type_id())), cnode);
  MS_EXCEPTION_IF_NULL(cnode);
  old_input->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_input)));
  cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  // The cast keeps the incoming tensor's format; only the device type changes
  // (fp16 in, fp32 out).
  std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
  std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
  std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
  std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
  kernel::KernelBuildInfo::KernelBuildInfoBuilder node_info_builder;
  node_info_builder.SetInputsFormat(cnode_input_format);
  node_info_builder.SetInputsDeviceType(cnode_input_type);
  node_info_builder.SetOutputsFormat(cnode_output_format);
  node_info_builder.SetOutputsDeviceType(cnode_output_type);
  node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  node_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto info_1 = node_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
  (void)mng->Replace(old_input, cnode);
  return true;
}
292 }  // namespace opt
293 }  // namespace mindspore
294