1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdexcept>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"  // from @llvm-project
20 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"  // from @llvm-project
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"  // from @llvm-project
22 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"  // from @llvm-project
23 #include "mlir/Conversion/MathToLibm/MathToLibm.h"  // from @llvm-project
24 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"  // from @llvm-project
25 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
26 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"  // from @llvm-project
27 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
28 #include "mlir/Dialect/Complex/IR/Complex.h"  // from @llvm-project
29 #include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
30 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
31 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // from @llvm-project
32 #include "mlir/Dialect/Math/IR/Math.h"  // from @llvm-project
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
34 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
36 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
38 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
40 
41 namespace mlir {
42 namespace kernel_gen {
43 namespace transforms {
44 namespace {
45 
// Symbol name of the C-interface runtime helper that actually launches a GPU
// kernel on behalf of the TensorFlow runtime (declared on demand in
// matchAndRewrite below). NOTE(review): "Libary" is a typo for "Library", but
// the identifier is used consistently within this file, so it is kept as-is.
constexpr StringRef kTfWrapperLibaryLaunchHelperName =
    "_mlir_ciface_tf_launch_kernel";

// Pulls in the generated pass base classes (e.g. TFKernelToLLVMPassBase).
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
51 
/// A rewrite pattern to convert gpu.launch_func operations into a runtime call
/// for the TensorFlow runtime.
class ConvertLaunchFuncOpToTfRuntimeCallPattern
    : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
 public:
  // `gpu_binary_annotation` names the attribute on the gpu.module that holds
  // the compiled device binary (e.g. a cubin or hsaco blob).
  ConvertLaunchFuncOpToTfRuntimeCallPattern(LLVMTypeConverter &type_converter,
                                            StringRef gpu_binary_annotation)
      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
        gpu_binary_annotation_(gpu_binary_annotation) {}

 private:
  // Packs the promoted kernel operands into a stack-allocated struct and
  // returns an array of type-erased (void*) pointers to its fields.
  Value generateParamsArray(gpu::LaunchFuncOp launch_op,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  // NOTE(review): declared but no definition is visible in this chunk —
  // possibly a leftover declaration; confirm against the full file.
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult matchAndRewrite(
      gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override;

  MLIRContext *context_ = &this->getTypeConverter()->getContext();

  // Cached LLVM dialect types used when emitting the runtime-call IR.
  Type llvm_void_type_ = LLVM::LLVMVoidType::get(context_);
  Type llvm_pointer_type_ =
      LLVM::LLVMPointerType::get(IntegerType::get(context_, 8));
  Type llvm_pointer_pointer_type_ =
      LLVM::LLVMPointerType::get(llvm_pointer_type_);
  Type llvm_int8_type_ = IntegerType::get(context_, 8);
  Type llvm_int32_type_ = IntegerType::get(context_, 32);
  Type llvm_int64_type_ = IntegerType::get(context_, 64);
  // Integer type as wide as a pointer in address space 0 on the target.
  Type llvm_intptr_type_ = IntegerType::get(
      context_, this->getTypeConverter()->getPointerBitwidth(0));

  llvm::SmallString<32> gpu_binary_annotation_;
};
87 
88 // Creates a struct containing all kernel parameters on the stack and returns
89 // an array of type-erased pointers to the fields of the struct. The array can
90 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
91 // The generated code is essentially as follows:
92 //
93 // %struct = alloca(sizeof(struct { Parameters... }))
94 // %array = alloca(NumParameters * sizeof(void *))
95 // for (i : [0, NumParameters))
96 //   %fieldPtr = llvm.getelementptr %struct[0, i]
97 //   llvm.store parameters[i], %fieldPtr
98 //   %elementPtr = llvm.getelementptr %array[i]
99 //   llvm.store %fieldPtr, %elementPtr
100 // return %array
generateParamsArray(gpu::LaunchFuncOp launch_op,ArrayRef<Value> operands,OpBuilder & builder) const101 Value ConvertLaunchFuncOpToTfRuntimeCallPattern::generateParamsArray(
102     gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
103     OpBuilder &builder) const {
104   auto loc = launch_op.getLoc();
105   auto num_kernel_operands = launch_op.getNumKernelOperands();
106   auto arguments = getTypeConverter()->promoteOperands(
107       loc, launch_op.getOperands().take_back(num_kernel_operands),
108       operands.take_back(num_kernel_operands), builder);
109   auto num_arguments = arguments.size();
110   SmallVector<Type, 4> argument_types;
111   argument_types.reserve(num_arguments);
112   for (auto argument : arguments) argument_types.push_back(argument.getType());
113   auto struct_type = LLVM::LLVMStructType::getNewIdentified(
114       context_, StringRef(), argument_types);
115   auto one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
116                                               builder.getI32IntegerAttr(1));
117   auto struct_ptr = builder.create<LLVM::AllocaOp>(
118       loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
119   auto array_size = builder.create<LLVM::ConstantOp>(
120       loc, llvm_int32_type_, builder.getI32IntegerAttr(num_arguments));
121   auto array_ptr = builder.create<LLVM::AllocaOp>(
122       loc, llvm_pointer_pointer_type_, array_size, /*alignment=*/0);
123   auto zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
124                                                builder.getI32IntegerAttr(0));
125   for (auto en : llvm::enumerate(arguments)) {
126     auto index = builder.create<LLVM::ConstantOp>(
127         loc, llvm_int32_type_, builder.getI32IntegerAttr(en.index()));
128     auto field_ptr = builder.create<LLVM::GEPOp>(
129         loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
130         ArrayRef<Value>{zero, index.getResult()});
131     builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
132     auto element_ptr = builder.create<LLVM::GEPOp>(
133         loc, llvm_pointer_pointer_type_, array_ptr, index.getResult());
134     auto casted =
135         builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type_, field_ptr);
136     builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
137   }
138   return array_ptr;
139 }
140 
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = <pointer to kernel function name>
// %2 = <see generateParamsArray>
// call %tfLaunchKernel(%ctx, %0, %1, <launch_op operands 0..5>, %2)
LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // The runtime helper call emitted below is synchronous, so async launches
  // cannot be lowered by this pattern.
  if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
    return rewriter.notifyMatchFailure(
        launch_op, "Cannot convert with async dependency or result.");
  }

  Location loc = launch_op.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launch_op, launch_op.getKernelModuleName());
  assert(kernel_module && "expected a kernel module");

  // The device binary must be attached to the gpu.module under the configured
  // annotation; fail the pattern (and emit a diagnostic) if it is missing.
  auto binary_attr =
      kernel_module->getAttrOfType<StringAttr>(gpu_binary_annotation_);
  if (!binary_attr) {
    kernel_module.emitOpError()
        << "missing " << gpu_binary_annotation_ << " attribute";
    return failure();
  }

  // Create a global for the module blob.
  SmallString<128> name_buffer(kernel_module.getName());
  name_buffer.append("_blob");
  Value module_blob =
      LLVM::createGlobalString(loc, rewriter, name_buffer.str(),
                               binary_attr.getValue(), LLVM::Linkage::Internal);

  // Make sure the trailing zero is included in the constant.
  auto kernel_name = launch_op.getKernelName();
  SmallString<128> kernel_name_buffer(kernel_name);
  kernel_name_buffer.push_back('\0');

  // Create a global for the kernel name.
  SmallString<128> kernel_name_global_name_buffer;
  auto kernel_name_global_name =
      (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
          .toStringRef(kernel_name_global_name_buffer);
  auto kernel_name_global =
      LLVM::createGlobalString(loc, rewriter, kernel_name_global_name,
                               kernel_name_buffer, LLVM::Linkage::Internal);

  // Adaptor gives named access to the already-converted launch operands
  // (grid/block sizes etc.).
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());

  // The TensorFlow OpKernelContext is the first argument of the surrounding
  // LLVMFunc.
  Value context_arg =
      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
  auto kernel_params = generateParamsArray(launch_op, operands, rewriter);

  // Declare the runtime launch helper at module scope unless a previous
  // rewrite already did so.
  auto function = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
      launch_op, kTfWrapperLibaryLaunchHelperName);
  if (!function) {
    // InsertionGuard restores the rewriter's insertion point after the
    // function declaration is created at the top of the module.
    PatternRewriter::InsertionGuard guard(rewriter);
    auto function_type = LLVM::LLVMFunctionType::get(
        llvm_void_type_,
        {
            llvm_pointer_type_,         /* void* context */
            llvm_pointer_type_,         /* void* module_blob */
            llvm_pointer_type_,         /* void* function_name */
            llvm_intptr_type_,          /* intptr_t grid_x_dim */
            llvm_intptr_type_,          /* intptr_t grid_y_dim */
            llvm_intptr_type_,          /* intptr_t grid_z_dim */
            llvm_intptr_type_,          /* intptr_t block_x_dim */
            llvm_intptr_type_,          /* intptr_t block_y_dim */
            llvm_intptr_type_,          /* intptr_t block_z_dim */
            llvm_pointer_pointer_type_, /* void **kernel_params */
        });
    rewriter.setInsertionPointToStart(
        launch_op->getParentOfType<ModuleOp>().getBody());
    function = rewriter.create<LLVM::LLVMFuncOp>(
        loc, kTfWrapperLibaryLaunchHelperName, function_type);
  }
  // Emit the actual runtime call and drop the original launch op.
  rewriter.create<LLVM::CallOp>(
      loc, TypeRange(), rewriter.getSymbolRefAttr(function),
      ArrayRef<Value>{
          context_arg, module_blob, kernel_name_global, adaptor.gridSizeX(),
          adaptor.gridSizeY(), adaptor.gridSizeZ(), adaptor.blockSizeX(),
          adaptor.blockSizeY(), adaptor.blockSizeZ(), kernel_params});

  rewriter.eraseOp(launch_op);
  return success();
}
236 
// Module pass that lowers the host-side IR (std, math, complex, memref,
// vector, gpu.launch_func and tf_framework ops) to the LLVM dialect, then
// strips the now-unneeded gpu.module bodies.
class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

 public:
  // `blob_annotation` names the gpu.module attribute holding the compiled
  // device binary; an empty string keeps the generated pass-option default.
  explicit TFKernelToLLVMPass(StringRef blob_annotation) {
    if (!blob_annotation.empty()) {
      blob_annotation_ = blob_annotation.str();
    }
  }

  void runOnOperation() override {
    ModuleOp m = getOperation();

    // Populate type conversions.
    MLIRContext *ctx = m.getContext();
    LLVMTypeConverter type_converter(ctx);
    // Opaque tf_framework types lower to untyped (i8*) pointers.
    type_converter.addConversion([&](tf_framework::OpKernelContextType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });
    type_converter.addConversion([&](tf_framework::JITCallableType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });

    // Populate patterns.
    RewritePatternSet patterns(&getContext());
    populateStdExpandOpsPatterns(patterns);
    populateMemRefToLLVMConversionPatterns(type_converter, patterns);
    populateMathToLLVMConversionPatterns(type_converter, patterns);
    populateStdToLLVMConversionPatterns(type_converter, patterns);
    populateComplexToLLVMConversionPatterns(type_converter, patterns);
    populateVectorToLLVMConversionPatterns(type_converter, patterns);
    populateMathToLibmConversionPatterns(patterns, 0);
    tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
                                                              &patterns);
    patterns.insert<ConvertLaunchFuncOpToTfRuntimeCallPattern>(
        type_converter, blob_annotation_);
    // Set target.
    ConversionTarget target(*ctx);
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addIllegalDialect<StandardOpsDialect, complex::ComplexDialect,
                             gpu::GPUDialect, tf_framework::TFFrameworkDialect,
                             math::MathDialect>();
    target.addIllegalOp<UnrealizedConversionCastOp>();
    // Mark modules as legal.
    target.addLegalOp<ModuleOp, gpu::GPUModuleOp>();
    // Do not look into gpu modules, only consider host-side.
    target.markOpRecursivelyLegal<gpu::GPUModuleOp>();

    // Full conversion: any remaining illegal op is a hard failure.
    if (failed(applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }

    // Finally, strip the GPU modules, as they are no longer needed.
    for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
      op.erase();
    }
  }
};
297 
298 }  // namespace
299 
CreateTFKernelToLLVMPass(StringRef blob_annotation)300 std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass(
301     StringRef blob_annotation) {
302   return std::make_unique<TFKernelToLLVMPass>(blob_annotation);
303 }
304 
305 }  // namespace transforms
306 }  // namespace kernel_gen
307 }  // namespace mlir
308