/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"

#include "absl/memory/memory.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"  // from @llvm-project
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"  // from @llvm-project
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"  // from @llvm-project
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/GPU/Passes.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Bufferize.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace mlir_gpu {

Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) {
  mlir::PassManager pm(module.getContext());
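  // Apply pass-manager options (e.g. IR dumping) that were configured via
  // command-line flags.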
  tensorflow::applyTensorflowAndCLOptions(pm);

  // We have to anticipate later unrolling in tiling to make sure that we get
  // the requested tiling after unrolling. Compute the new tiling here if
  // needed.
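  // Illustrative example (values assumed, not taken from the options): with
  // tile_sizes = {16, 64} and unroll_factors = {4, 2}, this computes
  // tiling_for_unrolling = {16 * 4, 64 * 2} = {64, 128} and keeps
  // as_int64 = {4, 2} for the later loop tiling that feeds unrolling.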
  llvm::SmallVector<unsigned, 4> tiling_for_unrolling;
  llvm::SmallVector<int64_t, 4> as_int64;
  if (!options.unroll_factors.empty()) {
    tiling_for_unrolling.reserve(options.tile_sizes.size());
    for (auto pair : llvm::zip(options.tile_sizes, options.unroll_factors)) {
      tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
      as_int64.push_back(std::get<1>(pair));
    }
  } else {
    tiling_for_unrolling.append(options.tile_sizes.begin(),
                                options.tile_sizes.end());
  }

  // Legalize from HLO to LHLO.
  pm.addPass(::mlir::mhlo::createLegalizeToLhloPass());
  // Move `AllocOp`s and insert missing `DeallocOp`s.
  pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferHoistingPass());
  pm.addNestedPass<mlir::FuncOp>(::mlir::createBufferDeallocationPass());
  // Next, we can strip the outer fusion operation.
  pm.addNestedPass<mlir::FuncOp>(createFusionOpRemoverPass());
  // Remove unnecessary LHLO copies.
  pm.addNestedPass<mlir::FuncOp>(::mlir::createCopyRemovalPass());
  // Legalize reduce operations directly to GPU dialect.
  pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLegalizeToGpuPass());
  // Transform LHLO operations to LinAlg.
  pm.addNestedPass<mlir::FuncOp>(
      ::mlir::lmhlo::createLegalizeLhloToLinalgPass());
  // Fuse linalg operations.
  pm.addNestedPass<mlir::FuncOp>(::mlir::lmhlo::createLhloFuseLinalgPass(
      /*use_parallel_loops=*/true, tiling_for_unrolling));
  // Transform the Linalg operations inside of the loop nest into parallel
  // loops.
  pm.addNestedPass<mlir::FuncOp>(
      ::mlir::createConvertLinalgToParallelLoopsPass());
  // Canonicalize the code to simplify index computations. This is needed so
  // that loop bounds have the same value.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Fuse the inner-most loops.
  pm.addNestedPass<mlir::FuncOp>(createFuseInnerParallelLoopsPass());
  // Run CSE to ensure that loads and stores to the same subview get
  // recognized as such.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Forward stores to buffers to loads.
  pm.addNestedPass<mlir::FuncOp>(createStoreForwardingPass());
  // Remove now unused temporary buffers.
  pm.addNestedPass<mlir::FuncOp>(createDeadTempBufferRemovalPass());
  if (!options.unroll_factors.empty()) {
    pm.addNestedPass<mlir::FuncOp>(
        ::mlir::createParallelLoopTilingPass(as_int64));
  }
  // Project all loop dimensions to X if necessary.
  if (options.collapse_parallel_loops) {
    pm.addNestedPass<mlir::FuncOp>(
        createParallelLoopCollapsingToFirstDimPass());
  }
  // Some basic cleanup.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Greedily map the remaining loops to GPU hardware dimensions.
  pm.addNestedPass<::mlir::FuncOp>(createMapParallelLoopsPass());
  // Apply the mapping.
  pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
  // Some basic cleanup.
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Make loops with min bounds into a conditional plus static bounds.
  // Only do this if we unrolled in the first place.
  if (!options.unroll_factors.empty()) {
    pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass());
  }
  // Approximate trigonometric functions if requested.
  if (options.use_approximations) {
    pm.addNestedPass<::mlir::FuncOp>(
        ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass());
  }
  // Outline GPU launch bodies into kernel functions, turning plain launches
  // into kernel launches.
  pm.addPass(::mlir::createGpuKernelOutliningPass());
  // Make sure the kernel signature resembles the original function's
  // signature.
  if (options.rewrite_signature) {
    pm.addNestedPass<::mlir::FuncOp>(createRewriteKernelSignaturePass());
  }
  if (failed(pm.run(module))) {
    return InternalError("Lowering to GPU kernels failed.");
  }
  return Status::OK();
}

namespace {

/// A pass that does the final lowering to NVVM. It collects all the patterns
/// that are currently required, mixing the std, linalg, and gpu dialects.
class LowerToNVVMPass
    : public ::mlir::PassWrapper<
          LowerToNVVMPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
  }

 public:
  void runOnOperation() override {
    ::mlir::gpu::GPUModuleOp m = getOperation();

    ::mlir::OwningRewritePatternList patterns;
    ::mlir::LLVMTypeConverter converter(m.getContext());
    ::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
    // TODO(b/145824979) Remove linalg once sliceop is in std.
    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns);
    ::mlir::populateGpuToNVVMConversionPatterns(converter, patterns);
    ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
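    // applyFullConversion below fails if any op that is illegal for the
    // NVVM/LLVM target remains after these patterns have been applied.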
    ::mlir::ConversionTarget target(getContext());
    ::mlir::configureGpuToNVVMConversionLegality(target);
    if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
  ::mlir::PassManager pm(module.getContext());
  // We cannot verify, as the signature of the kernel is rewritten.
  pm.enableVerifier(false);
  tensorflow::applyTensorflowAndCLOptions(pm);

  // Rewrite kernel functions to LLVM IR.
  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
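  // Lower structured control flow to plain CFG branches first; the LLVM
  // dialect has no structured control-flow ops.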
  kernelPm.addPass(::mlir::createLowerToCFGPass());
  kernelPm.addPass(absl::make_unique<LowerToNVVMPass>());
  // Some basic cleanup.
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Remove all location information to prevent a debug build.
  pm.addPass(::mlir::createStripDebugInfoPass());

  if (failed(pm.run(module))) {
    return InternalError("Lowering to NVVM IR failed.");
  }
  return Status::OK();
}

namespace {

/// A pass that does the final lowering to ROCDL. It collects all the patterns
/// that are currently required, mixing the std, linalg, and gpu dialects.
class LowerToROCDLPass
    : public ::mlir::PassWrapper<
          LowerToROCDLPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::ROCDL::ROCDLDialect, mlir::LLVM::LLVMDialect>();
  }

 public:
  void runOnOperation() override {
    ::mlir::gpu::GPUModuleOp m = getOperation();

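    // First apply the generic GPU rewrite patterns greedily, before the
    // dialect conversion to ROCDL below.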
    {
      ::mlir::OwningRewritePatternList patterns;
      ::mlir::populateGpuRewritePatterns(m.getContext(), patterns);
      ::mlir::applyPatternsAndFoldGreedily(m, std::move(patterns));
    }

    ::mlir::OwningRewritePatternList patterns;
    ::mlir::LLVMTypeConverter converter(m.getContext());
    ::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
    // TODO(b/145824979) Remove linalg once sliceop is in std.
    ::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns);
    ::mlir::populateGpuToROCDLConversionPatterns(converter, patterns);
    ::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());

    ::mlir::ConversionTarget target(getContext());
    ::mlir::configureGpuToROCDLConversionLegality(target);
    if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) {
  ::mlir::PassManager pm(module.getContext());
  // We cannot verify, as the signature of the kernel is rewritten.
  pm.enableVerifier(false);
  tensorflow::applyTensorflowAndCLOptions(pm);

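  // When VLOG(1) is enabled, print the IR after each pass to help debug the
  // lowering pipeline.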
  auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) {
    return VLOG_IS_ON(1);
  };
  pm.enableIRPrinting(/*shouldPrintBeforePass=*/{},
                      /*shouldPrintAfterPass=*/enable_if_vlog_is_on,
                      /*printModuleScope=*/false,
                      /*printAfterOnlyOnChange=*/false,
                      /*out=*/llvm::dbgs());

  // Rewrite kernel functions to LLVM IR.
  auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>();
  kernelPm.addPass(::mlir::createLowerToCFGPass());
  kernelPm.addPass(absl::make_unique<LowerToROCDLPass>());

  // Some basic cleanup.
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
  kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
  // Remove all location information to prevent a debug build.
  kernelPm.addPass(::mlir::createStripDebugInfoPass());

  if (failed(pm.run(module))) {
    return InternalError("Lowering to ROCDL IR failed.");
  }
  return Status::OK();
}

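// Creates a new module containing clones of all functions from the GPU
// modules nested inside `module`.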
StatusOr<mlir::ModuleOp> ExtractKernelModule(mlir::ModuleOp module) {
  auto kernelModule = ::mlir::ModuleOp::create(module.getLoc());
  // TODO(b/137624192): This also needs to resolve naming conflicts.
  module.walk([&kernelModule](mlir::gpu::GPUModuleOp nestedModule) {
    for (auto& fn : nestedModule.body().front()) {
      kernelModule.push_back(fn.clone());
    }
  });
  return kernelModule;
}

}  // namespace mlir_gpu
}  // namespace xla
284