1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/raise_reduction_precision.h"
17 #include "mindspore/core/ops/math_ops.h"
18 #include "mindspore/core/ops/array_ops.h"
19 #include "include/common/utils/utils.h"
20 #include "include/backend/optimizer/helper.h"
21 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
22 #include "backend/common/graph_kernel/graph_kernel_helper.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "kernel/kernel_build_info.h"
26 #include "include/backend/kernel_info.h"
27
28 namespace mindspore::graphkernel {
namespace {
// Fusion-type tag written into the kernel build info by CreateReduceSum.
constexpr const char *kPatternOpaque = "Opaque";
}  // namespace
32
IsFp16ReduceSum(const AnfNodePtr & node) const33 bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) const {
34 return IsPrimitiveCNode(node, prim::kPrimReduceSum) && AnfAlgo::GetInputDeviceDataType(node, 0) == kNumberTypeFloat16;
35 }
36
// Creates a Cast CNode in the func graph that owns `input`.
//   input:    node to be cast; it becomes the only data input of the new Cast.
//   dst_type: target type, used for the node description and stored as the kAttrDstType attribute.
//   format:   format passed to CreateCNode together with the shape of `input`.
// Returns the new Cast node, which is also tagged with kIsBackendCast = true.
// Raises (through MS_EXCEPTION_IF_NULL) when `input` does not belong to a func graph.
AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type,
                                               const std::string &format) const {
  auto owner_graph = input->func_graph();
  MS_EXCEPTION_IF_NULL(owner_graph);

  AnfNodePtrList cast_inputs = {NewValueNode(prim::kPrimCast), input};
  auto input_shape = GetShape(input);
  auto cast_node = CreateCNode(cast_inputs, owner_graph, {format, input_shape, dst_type});

  SetNodeAttrSafely(kAttrDstType, dst_type, cast_node);
  common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast_node);
  return cast_node;
}
47
// Rewrites the ReduceSum `node` in place so that it works in float32.
//   node:  the ReduceSum node to update; must be a CNode (checked with MS_EXCEPTION_IF_NULL).
//   input: node installed into CNode input slot 1 of `node`.
// The abstract is replaced by a float32 tensor of the node's current shape, and a fresh kernel build info is
// selected: input device types {float32, int64}, output device type float32, AKG kernel type and the
// kPatternOpaque fusion type. Formats and processor are taken over from the node's current build info.
// Returns `node` itself (no new node is created).
AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) const {
  auto reduce_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(reduce_cnode);
  reduce_cnode->set_input(1, input);

  auto fp32_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, GetShape(node));
  reduce_cnode->set_abstract(fp32_abstract);

  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  const auto first_input_format = AnfAlgo::GetInputFormat(node, 0);
  const auto second_input_format = AnfAlgo::GetInputFormat(node, 1);
  builder.SetInputsFormat({first_input_format, second_input_format});
  builder.SetInputsDeviceType({kFloat32->type_id(), kInt64->type_id()});

  const auto output_format = AnfAlgo::GetOutputFormat(node, 0);
  builder.SetOutputsFormat({output_format});
  builder.SetOutputsDeviceType({kFloat32->type_id()});

  builder.SetProcessor(AnfAlgo::GetProcessor(node));
  builder.SetKernelType(KernelType::AKG_KERNEL);
  builder.SetFusionType(kPatternOpaque);

  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), reduce_cnode.get());
  return node;
}
64
ReplaceNode(const AnfNodePtr & reduce_node,const AnfNodePtr & cast_node) const65 void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) const {
66 auto mng = reduce_node->func_graph()->manager();
67 MS_EXCEPTION_IF_NULL(mng);
68 // use a copy of user, since the following `mng->Replace` will change the original users of reduce_node.
69 auto users = mng->node_users()[reduce_node];
70 for (const auto &user : users) {
71 auto user_node = user.first;
72 size_t user_index = IntToSize(user.second);
73 if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
74 AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat32) {
75 if (!(mng->Replace(user_node, reduce_node))) {
76 MS_LOG(ERROR) << "Fail to replace node[" << user_node->fullname_with_scope() << "] with node["
77 << reduce_node->fullname_with_scope() << "]";
78 }
79 } else {
80 if (user_node->isa<CNode>()) {
81 user_node->cast<CNodePtr>()->set_input(user_index, cast_node);
82 }
83 }
84 }
85 }
86
Process(const FuncGraphPtr & func_graph) const87 bool RaiseReductionPrecision::Process(const FuncGraphPtr &func_graph) const {
88 auto mng = func_graph->manager();
89 if (mng == nullptr) {
90 mng = Manage(func_graph, true);
91 func_graph->set_manager(mng);
92 }
93 auto todos = TopoSort(func_graph->get_return());
94 bool changed = false;
95 for (auto node : todos) {
96 if (IsFp16ReduceSum(node)) {
97 auto cast1 = CreateCast(node->cast<CNodePtr>()->input(1), kFloat32, AnfAlgo::GetInputFormat(node, 0));
98 auto new_reduce = CreateReduceSum(node, cast1);
99 auto cast2 = CreateCast(new_reduce, kFloat16, AnfAlgo::GetOutputFormat(node, 0));
100 ReplaceNode(node, cast2);
101 changed = true;
102 }
103 }
104 if (changed) {
105 mng->RemoveRoots();
106 mng->KeepRoots({func_graph});
107 }
108 return changed;
109 }
110
Run(const FuncGraphPtr & func_graph)111 bool RaiseReductionPrecision::Run(const FuncGraphPtr &func_graph) {
112 auto mng = func_graph->manager();
113 if (mng == nullptr) {
114 mng = Manage(func_graph, true);
115 func_graph->set_manager(mng);
116 }
117 bool changed = false;
118 auto todos = TopoSort(func_graph->get_return());
119 for (const auto &node : todos) {
120 if (common::AnfAlgo::IsGraphKernel(node)) {
121 auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
122 MS_ERROR_IF_NULL(sub_func_graph);
123 changed = Process(sub_func_graph) || changed;
124 }
125 }
126 if (changed) {
127 GkUtils::UpdateFuncGraphManager(mng, func_graph);
128 }
129 return changed;
130 }
131 } // namespace mindspore::graphkernel
132