• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/raise_reduction_precision.h"
17 #include "mindspore/core/ops/math_ops.h"
18 #include "mindspore/core/ops/array_ops.h"
19 #include "include/common/utils/utils.h"
20 #include "include/backend/optimizer/helper.h"
21 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
22 #include "backend/common/graph_kernel/graph_kernel_helper.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "kernel/kernel_build_info.h"
26 #include "include/backend/kernel_info.h"
27 
28 namespace mindspore::graphkernel {
namespace {
// Fusion-type name passed to KernelBuildInfoBuilder::SetFusionType for the rewritten ReduceSum.
constexpr const char *kPatternOpaque = "Opaque";
}  // namespace
32 
IsFp16ReduceSum(const AnfNodePtr & node) const33 bool RaiseReductionPrecision::IsFp16ReduceSum(const AnfNodePtr &node) const {
34   return IsPrimitiveCNode(node, prim::kPrimReduceSum) && AnfAlgo::GetInputDeviceDataType(node, 0) == kNumberTypeFloat16;
35 }
36 
// Builds a Cast CNode in `input`'s func graph that converts `input` to `dst_type`.
//
// The new node is created with the given `format`, the shape of `input` and `dst_type` as its
// output description. It carries the kAttrDstType attribute and is flagged with kIsBackendCast.
//
// Raises (via MS_EXCEPTION_IF_NULL) when `input` or its func graph is null.
AnfNodePtr RaiseReductionPrecision::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type,
                                               const std::string &format) const {
  // `input` is dereferenced right below, so reject a null pointer explicitly instead of crashing.
  MS_EXCEPTION_IF_NULL(input);
  auto func_graph = input->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), input};
  auto cnode = CreateCNode(inputs, func_graph, {format, GetShape(input), dst_type});
  SetNodeAttrSafely(kAttrDstType, dst_type, cnode);
  common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cnode);
  return cnode;
}
47 
// Rewrites the ReduceSum `node` in place so that it computes in float32, and returns that same node.
//
// Input 1 of the node is replaced by `input`, the abstract becomes a float32 tensor with the node's
// shape, and a new kernel build info is selected: input device types {float32, int64}, output
// device type float32, AKG kernel type and the "Opaque" fusion type. Formats and the processor are
// read back from the node itself.
AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, const AnfNodePtr &input) const {
  auto reduce_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(reduce_cnode);
  reduce_cnode->set_input(1, input);
  auto fp32_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, GetShape(node));
  reduce_cnode->set_abstract(fp32_abstract);

  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  const auto first_input_format = AnfAlgo::GetInputFormat(node, 0);
  const auto second_input_format = AnfAlgo::GetInputFormat(node, 1);
  builder.SetInputsFormat({first_input_format, second_input_format});
  const auto fp32_type_id = kFloat32->type_id();
  const auto int64_type_id = kInt64->type_id();
  builder.SetInputsDeviceType({fp32_type_id, int64_type_id});
  const auto output_format = AnfAlgo::GetOutputFormat(node, 0);
  builder.SetOutputsFormat({output_format});
  builder.SetOutputsDeviceType({fp32_type_id});
  builder.SetProcessor(AnfAlgo::GetProcessor(node));
  builder.SetKernelType(KernelType::AKG_KERNEL);
  builder.SetFusionType(kPatternOpaque);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), reduce_cnode.get());
  // The caller receives the original node, which has been mutated above.
  return node;
}
64 
ReplaceNode(const AnfNodePtr & reduce_node,const AnfNodePtr & cast_node) const65 void RaiseReductionPrecision::ReplaceNode(const AnfNodePtr &reduce_node, const AnfNodePtr &cast_node) const {
66   auto mng = reduce_node->func_graph()->manager();
67   MS_EXCEPTION_IF_NULL(mng);
68   // use a copy of user, since the following `mng->Replace` will change the original users of reduce_node.
69   auto users = mng->node_users()[reduce_node];
70   for (const auto &user : users) {
71     auto user_node = user.first;
72     size_t user_index = IntToSize(user.second);
73     if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
74         AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat32) {
75       if (!(mng->Replace(user_node, reduce_node))) {
76         MS_LOG(ERROR) << "Fail to replace node[" << user_node->fullname_with_scope() << "] with node["
77                       << reduce_node->fullname_with_scope() << "]";
78       }
79     } else {
80       if (user_node->isa<CNode>()) {
81         user_node->cast<CNodePtr>()->set_input(user_index, cast_node);
82       }
83     }
84   }
85 }
86 
Process(const FuncGraphPtr & func_graph) const87 bool RaiseReductionPrecision::Process(const FuncGraphPtr &func_graph) const {
88   auto mng = func_graph->manager();
89   if (mng == nullptr) {
90     mng = Manage(func_graph, true);
91     func_graph->set_manager(mng);
92   }
93   auto todos = TopoSort(func_graph->get_return());
94   bool changed = false;
95   for (auto node : todos) {
96     if (IsFp16ReduceSum(node)) {
97       auto cast1 = CreateCast(node->cast<CNodePtr>()->input(1), kFloat32, AnfAlgo::GetInputFormat(node, 0));
98       auto new_reduce = CreateReduceSum(node, cast1);
99       auto cast2 = CreateCast(new_reduce, kFloat16, AnfAlgo::GetOutputFormat(node, 0));
100       ReplaceNode(node, cast2);
101       changed = true;
102     }
103   }
104   if (changed) {
105     mng->RemoveRoots();
106     mng->KeepRoots({func_graph});
107   }
108   return changed;
109 }
110 
Run(const FuncGraphPtr & func_graph)111 bool RaiseReductionPrecision::Run(const FuncGraphPtr &func_graph) {
112   auto mng = func_graph->manager();
113   if (mng == nullptr) {
114     mng = Manage(func_graph, true);
115     func_graph->set_manager(mng);
116   }
117   bool changed = false;
118   auto todos = TopoSort(func_graph->get_return());
119   for (const auto &node : todos) {
120     if (common::AnfAlgo::IsGraphKernel(node)) {
121       auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
122       MS_ERROR_IF_NULL(sub_func_graph);
123       changed = Process(sub_func_graph) || changed;
124     }
125   }
126   if (changed) {
127     GkUtils::UpdateFuncGraphManager(mng, func_graph);
128   }
129   return changed;
130 }
131 }  // namespace mindspore::graphkernel
132