• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
17 
18 #include "base/core_ops.h"
19 #include "utils/utils.h"
20 #include "backend/optimizer/common/helper.h"
21 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
22 #include "backend/session/anf_runtime_algorithm.h"
23 #include "ir/tensor.h"
24 #include "backend/kernel_compiler/kernel_build_info.h"
25 #include "backend/kernel_compiler/common_utils.h"
26 #include "runtime/device/kernel_info.h"
27 
28 namespace mindspore {
29 namespace opt {
IsFp16ReduceSum(const AnfNodePtr & node) const30 bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) const {
31   return IsPrimitiveCNode(node, prim::kPrimReduceSum) && AnfAlgo::GetInputDeviceDataType(node, 0) == kNumberTypeFloat16;
32 }
33 
/**
 * Creates a Cast CNode that consumes `input`, inside the function graph `input` belongs to.
 *
 * The new node is created with the given `format`, the shape reported by GetShape(input) and
 * `dst_type` as its type; its "dst_type" attribute is set to the string form of dst_type's type id.
 *
 * @param input    node to be cast; must be non-null and must belong to a function graph.
 * @param dst_type target type of the cast; must be non-null.
 * @param format   format passed to CreateCNode for the new node.
 * @return the newly created Cast node.
 */
AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type,
                                               const std::string &format) const {
  // Both pointers are dereferenced below; reject null through the same guard macro used for func_graph.
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(dst_type);
  auto func_graph = input->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), input};
  auto cnode = CreateCNode(inputs, func_graph, {.format = format, .shape = GetShape(input), .type = dst_type});
  SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(dst_type->type_id())), cnode);
  return cnode;
}
43 
/**
 * Rewrites the reduce node `node` in place so that it reads `input` and is described as float32.
 *
 * Input 1 of the node is replaced by `input`, its abstract becomes a float32 AbstractTensor with the
 * shape reported by GetShape(node), and a new kernel build info is selected for it: input/output
 * formats and processor are taken from the node itself, both device types are float32, the kernel
 * type is AKG_KERNEL and the fusion type is OPAQUE.
 *
 * @param node  the reduce node to modify; must be a CNode.
 * @param input the node to install as input 1.
 * @return `node` itself (no new node is created).
 */
AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) const {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // Re-wire the data input and switch the abstract to float32.
  cnode->set_input(1, input);
  auto fp32_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, GetShape(node));
  cnode->set_abstract(fp32_abstract);

  // Build the float32 kernel description, keeping the node's formats and processor.
  const auto fp32_type_id = kFloat32->type_id();
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({AnfAlgo::GetInputFormat(node, 0)});
  builder.SetInputsDeviceType({fp32_type_id});
  builder.SetOutputsFormat({AnfAlgo::GetOutputFormat(node, 0)});
  builder.SetOutputsDeviceType({fp32_type_id});
  builder.SetProcessor(AnfAlgo::GetProcessor(node));
  builder.SetKernelType(KernelType::AKG_KERNEL);
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto build_info = builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(build_info, cnode.get());
  return node;
}
60 
ReplaceNode(const AnfNodePtr & reduce_node,const AnfNodePtr & cast_node) const61 void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) const {
62   auto mng = reduce_node->func_graph()->manager();
63   MS_EXCEPTION_IF_NULL(mng);
64   // use a copy of user, since the following `mng->Replace` will change the original users of reduce_node.
65   auto users = mng->node_users()[reduce_node];
66   for (const auto &user : users) {
67     auto user_node = user.first;
68     size_t user_index = IntToSize(user.second);
69     if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
70         AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat32) {
71       if (!(mng->Replace(user_node, reduce_node))) {
72         MS_LOG(ERROR) << "Something happened error, when replacing nodes.";
73       }
74     } else {
75       if (user_node->isa<CNode>()) {
76         user_node->cast<CNodePtr>()->set_input(user_index, cast_node);
77       }
78     }
79   }
80 }
81 
Process(const FuncGraphPtr & func_graph)82 bool RaiseReductionPrecision::Process(const FuncGraphPtr &func_graph) {
83   auto mng = func_graph->manager();
84   if (mng == nullptr) {
85     mng = Manage(func_graph, true);
86     func_graph->set_manager(mng);
87   }
88   auto todos = TopoSort(func_graph->get_return());
89   bool changed = false;
90   for (auto node : todos) {
91     if (IsFp16ReduceSum(node)) {
92       auto cast1 = CreateCast(node->cast<CNodePtr>()->input(1), kFloat32, AnfAlgo::GetInputFormat(node, 0));
93       auto new_reduce = CreateReduceSum(node, cast1);
94       auto cast2 = CreateCast(new_reduce, kFloat16, AnfAlgo::GetOutputFormat(node, 0));
95       ReplaceNode(node, cast2);
96       changed = true;
97     }
98   }
99   if (changed) {
100     mng->RemoveRoots();
101     mng->KeepRoots({func_graph});
102   }
103   return changed;
104 }
105 
Run(const FuncGraphPtr & func_graph)106 bool RaiseReductionPrecision::Run(const FuncGraphPtr &func_graph) {
107   auto mng = func_graph->manager();
108   if (mng == nullptr) {
109     mng = Manage(func_graph, true);
110     func_graph->set_manager(mng);
111   }
112   bool changed = false;
113   auto todos = TopoSort(func_graph->get_return());
114   for (const auto &node : todos) {
115     if (AnfAlgo::IsGraphKernel(node)) {
116       auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
117       MS_ERROR_IF_NULL(sub_func_graph);
118       changed = Process(sub_func_graph) || changed;
119     }
120   }
121   if (changed) {
122     UpdateMng(mng, func_graph);
123   }
124   return changed;
125 }
126 }  // namespace opt
127 }  // namespace mindspore
128