/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"

#include <vector>
#include <string>
#include <memory>

#include "ir/func_graph.h"
#include "utils/ms_context.h"
#include "utils/context/graph_kernel_flags.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
#include "backend/optimizer/graph_kernel/insert_pad.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/cast_matmul_fusion.h"
#include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/value_graph_binder.h"
#include "backend/optimizer/graph_kernel/parallel_fusion.h"
#include "backend/optimizer/graph_kernel/optimize_assign.h"
#include "backend/optimizer/graph_kernel/split_umonad.h"
#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/axis_normalizer.h"
#include "backend/optimizer/graph_kernel/decrease_compute_precision.h"
#include "backend/optimizer/graph_kernel/decrease_transfer_precision.h"
#include "backend/optimizer/graph_kernel/tsa_atomic_add_to_first_tensor.h"
#include "backend/optimizer/graph_kernel/uss_atomic_add.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h"
#include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
#include "backend/optimizer/graph_kernel/rewrite_output_shape.h"

namespace mindspore {
namespace opt {
using context::OptLevel_1;
using context::OptLevel_2;
using context::OptLevel_3;
using context::OptLevel_MAX;
namespace {
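// Map a boolean feature flag to a pass opt level: an enabled flag runs the pass at OptLevel_1,
// while a disabled flag returns OptLevel_MAX so the pass is effectively skipped.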
inline unsigned int GetPassLevelByFlag(bool flag) { return flag ? OptLevel_1 : OptLevel_MAX; }
} // namespace

PassManagerPtr GraphKernelOptimizer::PreProcess() const {
  auto pm = std::make_shared<GraphKernelPassManager>(0, "preprocess");
  // Do cse before all passes of graphkernel
  pm->AddPass(std::make_shared<CommonSubexpressionElimination>("cse1"), OptLevel_1);

  // Save the original output info
  pm->AddPass(std::make_shared<SaveOutputShape>(), OptLevel_1);

  // Change Assign(p, a, U) to Assign(Depend(p, U), a)
  pm->AddPass(std::make_shared<SplitAssign>(), OptLevel_1, is_gpu);

  // Spread the MakeTuple input of UpdateState
  pm->AddPass(std::make_shared<SpreadUpdateState>(), OptLevel_1);
  // Eliminate the common nodes that are generated in SpreadUpdateState
  pm->AddPass(std::make_shared<CommonSubexpressionElimination>("cse2"), OptLevel_1);
  return pm;
}

PassManagerPtr GraphKernelOptimizer::Cluster() const {
  auto pm = std::make_shared<GraphKernelPassManager>(1, "cluster");

  // Expand complex op to composite kernels
  pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, is_gpu);

  // Expand complex basic kernels to composite kernels
  pm->AddPass(std::make_shared<GraphKernelExpander>(), OptLevel_1);

  // Cluster basic kernels and composite kernels
  pm->AddPass(std::make_shared<GraphKernelCluster>(), OptLevel_1);

  // Eliminate the outputs without external users
  pm->AddPass(std::make_shared<EliminateRedundantOutput>(), OptLevel_1);
  return pm;
}

PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
  auto pm = std::make_shared<GraphKernelPassManager>(2, "highlevelopt1");

  // Remove redundant Cast(bias, fp16) for Matmul input
  pm->AddPass(std::make_shared<CastMatmulFusion>(), OptLevel_2, is_ascend);

  // Reorder Cast and type-insensitive nodes
  pm->AddPass(std::make_shared<ReorderOps>(), OptLevel_2);

  // Normalize the Reduce axis
  pm->AddPass(std::make_shared<AxisNormalizer>(), OptLevel_1);

  // Replace Assign with InplaceAssign, and replace the original output with the overridden parameters
  pm->AddPass(std::make_shared<OptimizeAssign>(), OptLevel_2);
  pm->AddPass(std::make_shared<EliminateRedundantOutput>(), OptLevel_2);

  // Cast the input of ReduceSum from float16 to float32 for higher precision
  pm->AddPass(std::make_shared<RaiseReductionPrecision>(), OptLevel_2);

  // Insert PadAkg and UnPadAkg ops for MatMul
  pm->AddPass(std::make_shared<InsertPadOps>(), OptLevel_1, is_gpu);

  // Universal arithmetic simplification
  pm->AddPass(std::make_shared<ArithmeticSimplify>(), OptLevel_2, is_gpu);

  // Common subexpression elimination
  pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_2);

  // Eliminate unnecessary transform ops
  auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_trans_op_optimize);
  pm->AddPass(std::make_shared<TransformOpOptimizer>(), level, is_gpu);
  return pm;
}

PassManagerPtr GraphKernelOptimizer::Split() const {
  auto pm = std::make_shared<GraphKernelPassManager>(3, "split");
  // Make certain nodes redundant so that they are used by only one user,
  // which can avoid unnecessary input-output and get better performance.
  // Preprocess for ShapeOpsSplitter
  pm->AddPass(std::make_shared<ExtendOutputForUpdateState>(), OptLevel_1);
  std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape};
  pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops), OptLevel_1);

  // Split kernels according to the cost model
  pm->AddPass(std::make_shared<GraphKernelSplitter>(), OptLevel_1);

  // After Simplify and Splitter, a lot of redundant getitem/maketuple nodes are exposed;
  // use the GetitemTuple pass to delete them.
  pm->AddPass(std::make_shared<GetitemTuple>(), OptLevel_1);

  // Eliminate the redundant nodes that were copied above but not handled by GraphKernelSplitter
  pm->AddPass(std::make_shared<MergeOutputForUpdateState>(), OptLevel_1);
  pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_1);
  pm->AddPass(std::make_shared<EliminateRedundantOutput>(), OptLevel_1);
  return pm;
}

PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
  auto pm = std::make_shared<GraphKernelPassManager>(4, "highlevelopt2");
  // Enable atomic add
  pm->AddPass(std::make_shared<AtomicCleanInsertter>(), OptLevel_2);

  // Enable atomic add for stitch nodes.
  auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_stitch_fusion);
  pm->AddPass(std::make_shared<StitchAtomicCleanInsertter>(), level, is_gpu);

  // Enable low precision
  auto level_low_precision = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_low_precision);
  pm->AddPass(std::make_shared<DecreaseTransferPrecision>(), level_low_precision);
  pm->AddPass(std::make_shared<DecreaseComputePrecision>(), level_low_precision, is_ascend);

  // Enable atomic add for tsa (TensorScatterAdd) and uss (UnsortedSegmentSum)
  pm->AddPass(std::make_shared<TsaAtomicAddToFirstTensor>(), OptLevel_1);
  pm->AddPass(std::make_shared<UssAtomicAdd>(), OptLevel_1);

  return pm;
}

PassManagerPtr GraphKernelOptimizer::Combine() const {
  auto pm = std::make_shared<GraphKernelPassManager>(5, "combine");
  // Enable parallel fusion for the GPU device
  auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_parallel_fusion);
  pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)), level, is_gpu);

  return pm;
}

PassManagerPtr GraphKernelOptimizer::PostProcess() const {
  auto pm = std::make_shared<GraphKernelPassManager>(6, "postprocess");
  // Make a tuple for the inputs of UpdateState (the reverse of SpreadUpdateState)
  pm->AddPass(std::make_shared<ShrinkUpdateState>(), OptLevel_1);

  // Recover the original output info
  pm->AddPass(std::make_shared<GetitemTuple>(), OptLevel_1);
  pm->AddPass(std::make_shared<RewriteOutputShape>(), OptLevel_1);

  // Add the new tensors to the kernel_graph
  pm->AddPass(std::make_shared<BindValueToGraph>(), OptLevel_1);
  return pm;
}

void GraphKernelOptimizer::Run(const KernelGraphPtr &kernel_graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
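  // Cache the device target so that device-specific passes can be enabled or skipped accordingly.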
  is_gpu = (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
  is_ascend = (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);

  auto optimizer = std::make_shared<GraphOptimizer>("graph_kernel_optimizer");
  optimizer->AddPassManager(PreProcess());
  optimizer->AddPassManager(Cluster());
  optimizer->AddPassManager(HighLevelOpt1());
  optimizer->AddPassManager(Split());
  optimizer->AddPassManager(HighLevelOpt2());
  optimizer->AddPassManager(Combine());
  optimizer->AddPassManager(PostProcess());

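  // The kernel graph may not have a manager yet; create one so that the passes can modify the graph.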
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }
  (void)optimizer->Optimize(kernel_graph);
}

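// Entry point: run the whole graph kernel optimization pipeline on a kernel graph.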
void GraphKernelOptimize(const KernelGraphPtr &kernel_graph) { GraphKernelOptimizer().Run(kernel_graph); }
} // namespace opt
} // namespace mindspore