/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdexcept>

#include "llvm/ADT/STLExtras.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"  // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"  // from @llvm-project
#include "mlir/Dialect/Complex/IR/Complex.h"  // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // from @llvm-project
#include "mlir/Dialect/Math/IR/Math.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {

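// Name of the runtime helper, provided by the TensorFlow framework, that the
// lowered gpu.launch_func is rewritten to call.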
constexpr StringRef kTfWrapperLibaryLaunchHelperName =
    "tfKernelGenLaunchKernel";

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

/// A rewrite pattern to convert gpu.launch_func operations into a runtime call
/// for the TensorFlow runtime.
class ConvertLaunchFuncOpToTfRuntimeCallPattern
    : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
 public:
  ConvertLaunchFuncOpToTfRuntimeCallPattern(LLVMTypeConverter &type_converter,
                                            StringRef gpu_binary_annotation)
      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
        gpu_binary_annotation_(gpu_binary_annotation) {}

 private:
  Value generateParamsArray(gpu::LaunchFuncOp launch_op,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult matchAndRewrite(
      gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override;

  MLIRContext *context_ = &this->getTypeConverter()->getContext();

  Type llvm_void_type_ = LLVM::LLVMVoidType::get(context_);
  Type llvm_pointer_type_ =
      LLVM::LLVMPointerType::get(IntegerType::get(context_, 8));
  Type llvm_pointer_pointer_type_ =
      LLVM::LLVMPointerType::get(llvm_pointer_type_);
  Type llvm_int8_type_ = IntegerType::get(context_, 8);
  Type llvm_int32_type_ = IntegerType::get(context_, 32);
  Type llvm_int64_type_ = IntegerType::get(context_, 64);
  Type llvm_intptr_type_ = IntegerType::get(
      context_, this->getTypeConverter()->getPointerBitwidth(0));

  llvm::SmallString<32> gpu_binary_annotation_;
};

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToTfRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    OpBuilder &builder) const {
  auto loc = launch_op.getLoc();
  auto num_kernel_operands = launch_op.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launch_op.getOperands().take_back(num_kernel_operands),
      operands.take_back(num_kernel_operands), builder);
  auto num_arguments = arguments.size();
  SmallVector<Type, 4> argument_types;
  argument_types.reserve(num_arguments);
  for (auto argument : arguments) argument_types.push_back(argument.getType());
  auto struct_type = LLVM::LLVMStructType::getNewIdentified(
      context_, StringRef(), argument_types);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
                                              builder.getI32IntegerAttr(1));
  auto struct_ptr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
  auto array_size = builder.create<LLVM::ConstantOp>(
      loc, llvm_int32_type_, builder.getI32IntegerAttr(num_arguments));
  auto array_ptr = builder.create<LLVM::AllocaOp>(
      loc, llvm_pointer_pointer_type_, array_size, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
                                               builder.getI32IntegerAttr(0));
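  // Store each promoted argument into its struct field and record a
  // type-erased pointer to that field in the argument array.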
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvm_int32_type_, builder.getI32IntegerAttr(en.index()));
    auto field_ptr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
    auto element_ptr = builder.create<LLVM::GEPOp>(
        loc, llvm_pointer_pointer_type_, array_ptr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type_, field_ptr);
    builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
  }
  return array_ptr;
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = <pointer to kernel function name>
// %2 = <see generateParamsArray>
// call %tfLaunchKernel(%ctx, %0, %1, <launch_op operands 0..5>, %2)
LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
    return rewriter.notifyMatchFailure(
        launch_op, "Cannot convert with async dependency or result.");
  }

  Location loc = launch_op.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launch_op, launch_op.getKernelModuleName());
  assert(kernel_module && "expected a kernel module");

  auto binary_attr =
      kernel_module->getAttrOfType<StringAttr>(gpu_binary_annotation_);
  if (!binary_attr) {
    kernel_module.emitOpError()
        << "missing " << gpu_binary_annotation_ << " attribute";
    return failure();
  }

  // Create a global for the module blob.
  SmallString<128> name_buffer(kernel_module.getName());
  name_buffer.append("_blob");
  Value module_blob =
      LLVM::createGlobalString(loc, rewriter, name_buffer.str(),
                               binary_attr.getValue(), LLVM::Linkage::Internal);

  // Make sure the trailing zero is included in the constant.
  auto kernel_name = launch_op.getKernelName();
  SmallString<128> kernel_name_buffer(kernel_name);
  kernel_name_buffer.push_back('\0');

  // Create a global for the kernel name.
  SmallString<128> kernel_name_global_name_buffer;
  auto kernel_name_global_name =
      (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
          .toStringRef(kernel_name_global_name_buffer);
  auto kernel_name_global =
      LLVM::createGlobalString(loc, rewriter, kernel_name_global_name,
                               kernel_name_buffer, LLVM::Linkage::Internal);

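  // Bind the converted operands to the named gpu.launch_func accessors.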
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());

  // The TensorFlow OpKernelContext is the first argument of the surrounding
  // LLVMFunc.
  Value context_arg =
      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
  auto kernel_params = generateParamsArray(launch_op, operands, rewriter);

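  // Look up the runtime launch helper; declare it at module scope if it has
  // not been declared yet.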
  auto function = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
      launch_op, kTfWrapperLibaryLaunchHelperName);
  if (!function) {
    PatternRewriter::InsertionGuard guard(rewriter);
    auto function_type = LLVM::LLVMFunctionType::get(
        llvm_void_type_,
        {
            llvm_pointer_type_,         /* void* context */
            llvm_pointer_type_,         /* void* module_blob */
            llvm_pointer_type_,         /* void* function_name */
            llvm_intptr_type_,          /* intptr_t grid_x_dim */
            llvm_intptr_type_,          /* intptr_t grid_y_dim */
            llvm_intptr_type_,          /* intptr_t grid_z_dim */
            llvm_intptr_type_,          /* intptr_t block_x_dim */
            llvm_intptr_type_,          /* intptr_t block_y_dim */
            llvm_intptr_type_,          /* intptr_t block_z_dim */
            llvm_pointer_pointer_type_, /* void **kernel_params */
        });
    rewriter.setInsertionPointToStart(
        launch_op->getParentOfType<ModuleOp>().getBody());
    function = rewriter.create<LLVM::LLVMFuncOp>(
        loc, kTfWrapperLibaryLaunchHelperName, function_type);
  }
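  // Emit the call into the TF runtime, passing the kernel context, the GPU
  // binary blob, the kernel name, the grid and block dimensions, and the
  // packed kernel parameters.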
  rewriter.create<LLVM::CallOp>(
      loc, llvm_void_type_, rewriter.getSymbolRefAttr(function),
      ArrayRef<Value>{
          context_arg, module_blob, kernel_name_global, adaptor.gridSizeX(),
          adaptor.gridSizeY(), adaptor.gridSizeZ(), adaptor.blockSizeX(),
          adaptor.blockSizeY(), adaptor.blockSizeZ(), kernel_params});

  rewriter.eraseOp(launch_op);
  return success();
}

class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

 public:
  explicit TFKernelToLLVMPass(StringRef blob_annotation) {
    if (!blob_annotation.empty()) {
      blob_annotation_ = blob_annotation.str();
    }
  }

  void runOnOperation() override {
    ModuleOp m = getOperation();

    // Populate type conversions.
    MLIRContext *ctx = m.getContext();
    LLVMTypeConverter type_converter(ctx);
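    // Lower tf_framework.op_kernel_context values to opaque i8* handles.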
    type_converter.addConversion([&](tf_framework::OpKernelContextType type) {
      return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
    });

    // Populate patterns.
    OwningRewritePatternList patterns;

    populateStdExpandOpsPatterns(ctx, patterns);
    populateStdToLLVMConversionPatterns(type_converter, patterns);
    populateComplexToLLVMConversionPatterns(type_converter, patterns);
    tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
                                                              &patterns);
    patterns.insert<ConvertLaunchFuncOpToTfRuntimeCallPattern>(
        type_converter, blob_annotation_);
    // Set target.
    ConversionTarget target(*ctx);
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addIllegalDialect<StandardOpsDialect, complex::ComplexDialect,
                             gpu::GPUDialect, tf_framework::TFFrameworkDialect,
                             math::MathDialect>();
    target.addIllegalOp<LLVM::DialectCastOp>();
    // Mark modules as legal.
    target.addLegalOp<ModuleOp, ModuleTerminatorOp, gpu::GPUModuleOp>();
    // Do not look into gpu modules, only consider host-side.
    target.markOpRecursivelyLegal<gpu::GPUModuleOp>();

    if (failed(applyFullConversion(m, target, std::move(patterns)))) {
      signalPassFailure();
    }

    // Finally, strip the GPU modules, as they are no longer needed.
    for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
      op.erase();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass(
    StringRef blob_annotation) {
  return std::make_unique<TFKernelToLLVMPass>(blob_annotation);
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir