/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"

#include "absl/memory/memory.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"  // from @llvm-project
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"  // from @llvm-project
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"  // from @llvm-project
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/GPU/Passes.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Bufferize.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace mlir_gpu {

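// Illustrative usage of this entry point (the option values below are
// hypothetical and only show which fields the pipeline consults):
//
//   llvm::SmallVector<unsigned, 4> tiles = {256};
//   llvm::SmallVector<unsigned, 4> unrolls = {4};
//   LowerLHLOToGPUOptions options;
//   options.tile_sizes = tiles;
//   options.unroll_factors = unrolls;
//   TF_RETURN_IF_ERROR(LowerLHLOToGPU(module, options));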
Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) {
  mlir::PassManager pm(module.getContext());
  tensorflow::applyTensorflowAndCLOptions(pm);

  // We have to anticipate later unrolling in tiling to make sure that we get
  // the requested tiling after unrolling. Compute the new tiling here if
  // needed.
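  // For example, with tile_sizes = {16, 64} and unroll_factors = {4, 4}, we
  // tile by {64, 256} here; the later tiling by the unroll factors {4, 4}
  // then yields the requested {16, 64} iterations per tile.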
  llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
  llvm::SmallVector<int64_t, 4> as_int64;
  if (!options.unroll_factors.empty()) {
    tiling_for_unrolling.reserve(options.tile_sizes.size());
    for (auto pair : llvm::zip(options.tile_sizes, options.unroll_factors)) {
      tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
      as_int64.push_back(std::get<1>(pair));
    }
  } else {
    tiling_for_unrolling.append(options.tile_sizes.begin(),
                                options.tile_sizes.end());
  }

  // Legalize from HLO to LHLO.
  pm.addPass(::mlir::mhlo::createLegalizeToLhloPass());
  // Move `AllocOp`s and insert missing `DeallocOp`s.
  pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferHoistingPass());
  pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
  // Next, we can strip the outer fusion operation.
  pm.addNestedPass<mlir::FuncOp>(createFusionOpRemoverPass());
  // Remove unnecessary LHLO copies.
  pm.addNestedPass<mlir::FuncOp>(::mlir::createCopyRemovalPass());
  // Legalize reduce operations directly to GPU dialect.
  pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLegalizeToGpuPass());
  // Transform LHLO operations to LinAlg.
  pm.addNestedPass<mlir::FuncOp>(
      ::mlir::lmhlo::createLegalizeLhloToLinalgPass());
  // Fuse linalg operations.
  pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
      /*use_parallel_loops=*/true, tiling_for_unrolling));
  // Transform the Linalg operations inside of the loop nest into parallel
  // loops.
  pm.addNestedPass<mlir::FuncOp>(
      ::mlir::createConvertLinalgToParallelLoopsPass());
  // Canonicalize the code to simplify index computations. This is needed so
  // that loops with identical bounds are recognized as such, which enables
  // the loop fusion below.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Fuse the inner-most loops.
  pm.addNestedPass<mlir::FuncOp>(createFuseInnerParallelLoopsPass());
  // Run CSE to ensure that loads and stores to the same subview get
  // recognized as such.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Forward stores to buffers to loads.
  pm.addNestedPass<mlir::FuncOp>(createStoreForwardingPass());
  // Remove now unused temporary buffers.
  pm.addNestedPass<mlir::FuncOp>(createDeadTempBufferRemovalPass());
  if (!options.unroll_factors.empty()) {
    pm.addNestedPass<mlir::FuncOp>(
        ::mlir::createParallelLoopTilingPass(as_int64));
  }
  // Project all loop dimensions to X if necessary.
  if (options.collapse_parallel_loops) {
    pm.addNestedPass<mlir::FuncOp>(
        createParallelLoopCollapsingToFirstDimPass());
  }
  // Some basic cleanup.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Greedily map the remaining loop to GPU hardware dimensions.
  pm.addNestedPass<::mlir::FuncOp>(createMapParallelLoopsPass());
  // Apply the mapping.
  pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
  // Some basic cleanup.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Make loops with min bounds into a conditional plus static bounds.
  // Only do this if we unrolled in the first place.
  if (!options.unroll_factors.empty()) {
    pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass());
  }
  // Approximate trigonometric functions if requested.
  if (options.use_approximations) {
    pm.addNestedPass<::mlir::FuncOp>(
        ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass());
  }
  // Outline launch bodies into kernel functions, turning launches into
  // launches of kernels.
  pm.addPass(::mlir::createGpuKernelOutliningPass());
  // Make sure the kernel signature resembles the original function's
  // signature.
  if (options.rewrite_signature) {
    pm.addNestedPass<::mlir::FuncOp>(createRewriteKernelSignaturePass());
  }
  if (failed(pm.run(module))) {
    return InternalError("Lowering to GPU kernels failed.");
  }
  return Status::OK();
}

namespace {

/// A pass that does the final lowering to NVVM. It collects all the patterns
/// that are currently required, mixing std, linalg and gpu.
class LowerToNVVMPass
    : public ::mlir::PassWrapper<
          LowerToNVVMPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
  }

 public:
  void runOnOperation() override {
    ::mlir::gpu::GPUModuleOp m = getOperation();

    ::mlir::OwningRewritePatternList patterns;
    ::mlir::LLVMTypeConverter converter(m.getContext());
    ::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
    // TODO(b/145824979) Remove linalg once sliceop is in std.
    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns);
    ::mlir::populateGpuToNVVMConversionPatterns(converter, patterns);
    ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
    ::mlir::ConversionTarget target(getContext());
    ::mlir::configureGpuToNVVMConversionLegality(target);
    if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
  ::mlir::PassManager pm(module.getContext());
  // We cannot verify as the signature of the kernel is rewritten.
  pm.enableVerifier(false);
  tensorflow::applyTensorflowAndCLOptions(pm);

  // Rewrite kernel functions to LLVM IR.
  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
  kernelPm.addPass(::mlir::createLowerToCFGPass());
  kernelPm.addPass(absl::make_unique<LowerToNVVMPass>());
  // Some basic cleanup.
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Remove all location information so that no debug information is emitted.
  pm.addPass(::mlir::createStripDebugInfoPass());

  if (failed(pm.run(module))) {
    return InternalError("Lowering to NVVM IR failed.");
  }
  return Status::OK();
}

namespace {

/// A pass that does the final lowering to ROCDL. It collects all the patterns
/// that are currently required, mixing std, linalg and gpu.
class LowerToROCDLPass
    : public ::mlir::PassWrapper<
          LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::ROCDL::ROCDLDialect, mlir::LLVM::LLVMDialect>();
  }

 public:
  void runOnOperation() override {
    ::mlir::gpu::GPUModuleOp m = getOperation();

    {
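      // Expand higher-level GPU dialect ops (e.g. gpu.all_reduce) with the
      // generic GPU rewrite patterns before running the full conversion.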
      ::mlir::OwningRewritePatternList patterns;
      ::mlir::populateGpuRewritePatterns(m.getContext(), patterns);
      ::mlir::applyPatternsAndFoldGreedily(m, std::move(patterns));
    }

    ::mlir::OwningRewritePatternList patterns;
    ::mlir::LLVMTypeConverter converter(m.getContext());
    ::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
    // TODO(b/145824979) Remove linalg once sliceop is in std.
    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns);
    ::mlir::populateGpuToROCDLConversionPatterns(converter, patterns);
    ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());

    ::mlir::ConversionTarget target(getContext());
    ::mlir::configureGpuToROCDLConversionLegality(target);
    if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) {
  ::mlir::PassManager pm(module.getContext());
  // We cannot verify as the signature of the kernel is rewritten.
  pm.enableVerifier(false);
  tensorflow::applyTensorflowAndCLOptions(pm);

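  // Dump the IR after each pass to llvm::dbgs() when VLOG level 1 is on.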
  auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
    return VLOG_IS_ON(1);
  };
  pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
                      /*shouldPrintAfterPass=*/enable_if_vlog_is_on,
                      /*printModuleScope=*/false,
                      /*printAfterOnlyOnChange=*/false,
                      /*out=*/llvm::dbgs());

  // Rewrite kernel functions to LLVM IR.
  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
  kernelPm.addPass(::mlir::createLowerToCFGPass());
  kernelPm.addPass(absl::make_unique<LowerToROCDLPass>());

  // Some basic cleanup.
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Remove all location information so that no debug information is emitted.
  kernelPm.addPass(::mlir::createStripDebugInfoPass());

  if (failed(pm.run(module))) {
    return InternalError("Lowering to ROCDL IR failed.");
  }
  return Status::OK();
}

StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module) {
  auto kernelModule = ::mlir::ModuleOp::create(module.getLoc());
  // TODO(b/137624192): This also needs to resolve naming conflicts.
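  // Clone every op from the body of each nested GPU module into the new
  // top-level module.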
  module.walk([&kernelModule](mlir::gpu::GPUModuleOp nestedModule) {
    for (auto& fn : nestedModule.body().front()) {
      kernelModule.push_back(fn.clone());
    }
  });
  return kernelModule;
}

}  // namespace mlir_gpu
}  // namespace xla
