1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
17
18 #include "base/core_ops.h"
19 #include "utils/utils.h"
20 #include "backend/optimizer/common/helper.h"
21 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
22 #include "backend/session/anf_runtime_algorithm.h"
23 #include "ir/tensor.h"
24 #include "backend/kernel_compiler/kernel_build_info.h"
25 #include "backend/kernel_compiler/common_utils.h"
26 #include "runtime/device/kernel_info.h"
27
28 namespace mindspore {
29 namespace opt {
IsFp16ReduceSum(const AnfNodePtr & node) const30 bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) const {
31 return IsPrimitiveCNode(node, prim::kPrimReduceSum) && AnfAlgo::GetInputDeviceDataType(node, 0) == kNumberTypeFloat16;
32 }
33
// Creates a Cast CNode in the func graph that owns `input`.
//
// The new node takes `input` as its only data input. `format`, the shape of `input` (from GetShape)
// and `dst_type` are handed to CreateCNode as the output info of the new node, and the node gets a
// "dst_type" attribute holding the string form of the destination type id.
//
// Raises (through MS_EXCEPTION_IF_NULL) when `input`, `dst_type` or the owning func graph is null.
// Returns the newly created Cast node.
AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type,
                                               const std::string &format) const {
  // Both pointers are dereferenced below, so validate them up front instead of crashing on null.
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(dst_type);
  auto func_graph = input->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), input};
  auto cnode = CreateCNode(inputs, func_graph, {.format = format, .shape = GetShape(input), .type = dst_type});
  SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(dst_type->type_id())), cnode);
  return cnode;
}
43
// Turns the given reduce node into a float32 one, in place.
//
// Input 1 of the CNode is replaced by `input`, the abstract becomes a float32 tensor with the node's
// current shape, and a fresh kernel build info is selected that keeps the node's existing input/output
// formats and processor but uses float32 device types, AKG_KERNEL kernel type and OPAQUE fusion type.
//
// Raises (through MS_EXCEPTION_IF_NULL) when `node` is not a CNode.
// Returns `node` itself; no new node is created.
AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) const {
  auto reduce_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(reduce_cnode);

  // Rewire the data input and publish the new float32 abstract.
  reduce_cnode->set_input(1, input);
  auto fp32_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, GetShape(node));
  reduce_cnode->set_abstract(fp32_abstract);

  // Rebuild the kernel info: formats and processor are read back from the node, data types become float32.
  const auto fp32_type_id = kFloat32->type_id();
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({AnfAlgo::GetInputFormat(node, 0)});
  builder.SetInputsDeviceType({fp32_type_id});
  builder.SetOutputsFormat({AnfAlgo::GetOutputFormat(node, 0)});
  builder.SetOutputsDeviceType({fp32_type_id});
  builder.SetProcessor(AnfAlgo::GetProcessor(node));
  builder.SetKernelType(KernelType::AKG_KERNEL);
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto build_info = builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(build_info, reduce_cnode.get());
  return node;
}
60
ReplaceNode(const AnfNodePtr & reduce_node,const AnfNodePtr & cast_node) const61 void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) const {
62 auto mng = reduce_node->func_graph()->manager();
63 MS_EXCEPTION_IF_NULL(mng);
64 // use a copy of user, since the following `mng->Replace` will change the original users of reduce_node.
65 auto users = mng->node_users()[reduce_node];
66 for (const auto &user : users) {
67 auto user_node = user.first;
68 size_t user_index = IntToSize(user.second);
69 if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
70 AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat32) {
71 if (!(mng->Replace(user_node, reduce_node))) {
72 MS_LOG(ERROR) << "Something happened error, when replacing nodes.";
73 }
74 } else {
75 if (user_node->isa<CNode>()) {
76 user_node->cast<CNodePtr>()->set_input(user_index, cast_node);
77 }
78 }
79 }
80 }
81
Process(const FuncGraphPtr & func_graph)82 bool RaiseReductionPrecision::Process(const FuncGraphPtr &func_graph) {
83 auto mng = func_graph->manager();
84 if (mng == nullptr) {
85 mng = Manage(func_graph, true);
86 func_graph->set_manager(mng);
87 }
88 auto todos = TopoSort(func_graph->get_return());
89 bool changed = false;
90 for (auto node : todos) {
91 if (IsFp16ReduceSum(node)) {
92 auto cast1 = CreateCast(node->cast<CNodePtr>()->input(1), kFloat32, AnfAlgo::GetInputFormat(node, 0));
93 auto new_reduce = CreateReduceSum(node, cast1);
94 auto cast2 = CreateCast(new_reduce, kFloat16, AnfAlgo::GetOutputFormat(node, 0));
95 ReplaceNode(node, cast2);
96 changed = true;
97 }
98 }
99 if (changed) {
100 mng->RemoveRoots();
101 mng->KeepRoots({func_graph});
102 }
103 return changed;
104 }
105
Run(const FuncGraphPtr & func_graph)106 bool RaiseReductionPrecision::Run(const FuncGraphPtr &func_graph) {
107 auto mng = func_graph->manager();
108 if (mng == nullptr) {
109 mng = Manage(func_graph, true);
110 func_graph->set_manager(mng);
111 }
112 bool changed = false;
113 auto todos = TopoSort(func_graph->get_return());
114 for (const auto &node : todos) {
115 if (AnfAlgo::IsGraphKernel(node)) {
116 auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
117 MS_ERROR_IF_NULL(sub_func_graph);
118 changed = Process(sub_func_graph) || changed;
119 }
120 }
121 if (changed) {
122 UpdateMng(mng, func_graph);
123 }
124 return changed;
125 }
126 } // namespace opt
127 } // namespace mindspore
128