1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdexcept>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project
20 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
21 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
22 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
23 #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
24 #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
26 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
27 #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
28 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
29 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
33 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
35
36 namespace mlir {
37 namespace kernel_gen {
38 namespace transforms {
39 namespace {
40
// Symbol name of the runtime helper that performs the actual kernel launch.
// Must match the function exported by the TensorFlow kernel-gen wrapper
// library; it is declared on demand in matchAndRewrite below.
// NOTE(review): "Libary" is a typo in the identifier, kept as-is because the
// name is used elsewhere in this file.
constexpr StringRef kTfWrapperLibaryLaunchHelperName =
    "tfKernelGenLaunchKernel";
43
44 #define GEN_PASS_CLASSES
45 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
46
/// A rewrite pattern to convert gpu.launch_func operations into a runtime call
/// for the TensorFlow runtime.
class ConvertLaunchFuncOpToTfRuntimeCallPattern
    : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
 public:
  // `gpu_binary_annotation` names the module attribute that holds the compiled
  // GPU binary blob (e.g. a cubin or hsaco).
  ConvertLaunchFuncOpToTfRuntimeCallPattern(LLVMTypeConverter &type_converter,
                                            StringRef gpu_binary_annotation)
      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
        gpu_binary_annotation_(gpu_binary_annotation) {}

 private:
  // Packs the (promoted) kernel operands into a stack struct and returns an
  // array of type-erased pointers to its fields (see definition below).
  Value generateParamsArray(gpu::LaunchFuncOp launch_op,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult matchAndRewrite(
      gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override;

  MLIRContext *context_ = &this->getTypeConverter()->getContext();

  // Frequently used LLVM dialect types, cached once per pattern instance.
  Type llvm_void_type_ = LLVM::LLVMVoidType::get(context_);
  Type llvm_pointer_type_ =
      LLVM::LLVMPointerType::get(IntegerType::get(context_, 8));
  Type llvm_pointer_pointer_type_ =
      LLVM::LLVMPointerType::get(llvm_pointer_type_);
  Type llvm_int8_type_ = IntegerType::get(context_, 8);
  Type llvm_int32_type_ = IntegerType::get(context_, 32);
  Type llvm_int64_type_ = IntegerType::get(context_, 64);
  // Integer type wide enough to hold a pointer in address space 0; used for
  // the grid/block dimension parameters of the launch helper.
  Type llvm_intptr_type_ = IntegerType::get(
      context_, this->getTypeConverter()->getPointerBitwidth(0));

  llvm::SmallString<32> gpu_binary_annotation_;
};
82
83 // Creates a struct containing all kernel parameters on the stack and returns
84 // an array of type-erased pointers to the fields of the struct. The array can
85 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
86 // The generated code is essentially as follows:
87 //
88 // %struct = alloca(sizeof(struct { Parameters... }))
89 // %array = alloca(NumParameters * sizeof(void *))
90 // for (i : [0, NumParameters))
91 // %fieldPtr = llvm.getelementptr %struct[0, i]
92 // llvm.store parameters[i], %fieldPtr
93 // %elementPtr = llvm.getelementptr %array[i]
94 // llvm.store %fieldPtr, %elementPtr
95 // return %array
generateParamsArray(gpu::LaunchFuncOp launch_op,ArrayRef<Value> operands,OpBuilder & builder) const96 Value ConvertLaunchFuncOpToTfRuntimeCallPattern::generateParamsArray(
97 gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
98 OpBuilder &builder) const {
99 auto loc = launch_op.getLoc();
100 auto num_kernel_operands = launch_op.getNumKernelOperands();
101 auto arguments = getTypeConverter()->promoteOperands(
102 loc, launch_op.getOperands().take_back(num_kernel_operands),
103 operands.take_back(num_kernel_operands), builder);
104 auto num_arguments = arguments.size();
105 SmallVector<Type, 4> argument_types;
106 argument_types.reserve(num_arguments);
107 for (auto argument : arguments) argument_types.push_back(argument.getType());
108 auto struct_type = LLVM::LLVMStructType::getNewIdentified(
109 context_, StringRef(), argument_types);
110 auto one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
111 builder.getI32IntegerAttr(1));
112 auto struct_ptr = builder.create<LLVM::AllocaOp>(
113 loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
114 auto array_size = builder.create<LLVM::ConstantOp>(
115 loc, llvm_int32_type_, builder.getI32IntegerAttr(num_arguments));
116 auto array_ptr = builder.create<LLVM::AllocaOp>(
117 loc, llvm_pointer_pointer_type_, array_size, /*alignment=*/0);
118 auto zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
119 builder.getI32IntegerAttr(0));
120 for (auto en : llvm::enumerate(arguments)) {
121 auto index = builder.create<LLVM::ConstantOp>(
122 loc, llvm_int32_type_, builder.getI32IntegerAttr(en.index()));
123 auto field_ptr = builder.create<LLVM::GEPOp>(
124 loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
125 ArrayRef<Value>{zero, index.getResult()});
126 builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
127 auto element_ptr = builder.create<LLVM::GEPOp>(
128 loc, llvm_pointer_pointer_type_, array_ptr, index.getResult());
129 auto casted =
130 builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type_, field_ptr);
131 builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
132 }
133 return array_ptr;
134 }
135
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = <pointer to kernel function name>
// %2 = <see generateParamsArray>
// call %tfLaunchKernel(%ctx, %0, %1, <launch_op operands 0..5>, %2)
LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // The runtime helper is synchronous; async launches cannot be lowered here.
  if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
    return rewriter.notifyMatchFailure(
        launch_op, "Cannot convert with async dependency or result.");
  }

  Location loc = launch_op.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launch_op, launch_op.getKernelModuleName());
  assert(kernel_module && "expected a kernel module");

  // The compiled device binary is attached to the gpu.module as a string
  // attribute whose name was passed to the pattern constructor.
  auto binary_attr =
      kernel_module->getAttrOfType<StringAttr>(gpu_binary_annotation_);
  if (!binary_attr) {
    kernel_module.emitOpError()
        << "missing " << gpu_binary_annotation_ << " attribute";
    return failure();
  }

  // Create a global for the module blob.
  SmallString<128> name_buffer(kernel_module.getName());
  name_buffer.append("_blob");
  Value module_blob =
      LLVM::createGlobalString(loc, rewriter, name_buffer.str(),
                               binary_attr.getValue(), LLVM::Linkage::Internal);

  // Make sure the trailing zero is included in the constant.
  auto kernel_name = launch_op.getKernelName();
  SmallString<128> kernel_name_buffer(kernel_name);
  kernel_name_buffer.push_back('\0');

  // Create a global for the kernel name. The global's name is derived from
  // both the module and kernel names so it is unique within the LLVM module.
  SmallString<128> kernel_name_global_name_buffer;
  auto kernel_name_global_name =
      (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
          .toStringRef(kernel_name_global_name_buffer);
  auto kernel_name_global =
      LLVM::createGlobalString(loc, rewriter, kernel_name_global_name,
                               kernel_name_buffer, LLVM::Linkage::Internal);

  // Adaptor gives typed access to the already-converted launch operands
  // (grid/block sizes etc.).
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());

  // The TensorFlow OpKernelContext is the first argument of the surrounding
  // LLVMFunc.
  Value context_arg =
      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
  auto kernel_params = generateParamsArray(launch_op, operands, rewriter);

  // Declare the launch helper at module scope if it does not exist yet.
  auto function = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
      launch_op, kTfWrapperLibaryLaunchHelperName);
  if (!function) {
    PatternRewriter::InsertionGuard guard(rewriter);
    auto function_type = LLVM::LLVMFunctionType::get(
        llvm_void_type_,
        {
            llvm_pointer_type_,         /* void* context */
            llvm_pointer_type_,         /* void* module_blob */
            llvm_pointer_type_,         /* void* function_name */
            llvm_intptr_type_,          /* intptr_t grid_x_dim */
            llvm_intptr_type_,          /* intptr_t grid_y_dim */
            llvm_intptr_type_,          /* intptr_t grid_z_dim */
            llvm_intptr_type_,          /* intptr_t block_x_dim */
            llvm_intptr_type_,          /* intptr_t block_y_dim */
            llvm_intptr_type_,          /* intptr_t block_z_dim */
            llvm_pointer_pointer_type_, /* void **kernel_params */
        });
    rewriter.setInsertionPointToStart(
        launch_op->getParentOfType<ModuleOp>().getBody());
    function = rewriter.create<LLVM::LLVMFuncOp>(
        loc, kTfWrapperLibaryLaunchHelperName, function_type);
  }
  // Emit the actual runtime call and drop the original launch op.
  rewriter.create<LLVM::CallOp>(
      loc, llvm_void_type_, rewriter.getSymbolRefAttr(function),
      ArrayRef<Value>{
          context_arg, module_blob, kernel_name_global, adaptor.gridSizeX(),
          adaptor.gridSizeY(), adaptor.gridSizeZ(), adaptor.blockSizeX(),
          adaptor.blockSizeY(), adaptor.blockSizeZ(), kernel_params});

  rewriter.eraseOp(launch_op);
  return success();
}
231
232 class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
getDependentDialects(DialectRegistry & registry) const233 void getDependentDialects(DialectRegistry ®istry) const override {
234 registry.insert<LLVM::LLVMDialect>();
235 }
236
237 public:
TFKernelToLLVMPass(StringRef blob_annotation)238 explicit TFKernelToLLVMPass(StringRef blob_annotation) {
239 if (!blob_annotation.empty()) {
240 blob_annotation_ = blob_annotation.str();
241 }
242 }
243
runOnOperation()244 void runOnOperation() override {
245 ModuleOp m = getOperation();
246
247 // Populate type conversions.
248 MLIRContext *ctx = m.getContext();
249 LLVMTypeConverter type_converter(ctx);
250 type_converter.addConversion([&](tf_framework::OpKernelContextType type) {
251 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
252 });
253
254 // Populate patterns.
255 OwningRewritePatternList patterns;
256
257 populateStdExpandOpsPatterns(ctx, patterns);
258 populateStdToLLVMConversionPatterns(type_converter, patterns);
259 populateComplexToLLVMConversionPatterns(type_converter, patterns);
260 tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
261 &patterns);
262 patterns.insert<ConvertLaunchFuncOpToTfRuntimeCallPattern>(
263 type_converter, blob_annotation_);
264 // Set target.
265 ConversionTarget target(*ctx);
266 target.addLegalDialect<LLVM::LLVMDialect>();
267 target.addIllegalDialect<StandardOpsDialect, complex::ComplexDialect,
268 gpu::GPUDialect, tf_framework::TFFrameworkDialect,
269 math::MathDialect>();
270 target.addIllegalOp<LLVM::DialectCastOp>();
271 // Mark modules as legal.
272 target.addLegalOp<ModuleOp, ModuleTerminatorOp, gpu::GPUModuleOp>();
273 // Do not look into gpu modules, only consider host-side.
274 target.markOpRecursivelyLegal<gpu::GPUModuleOp>();
275
276 if (failed(applyFullConversion(m, target, std::move(patterns)))) {
277 signalPassFailure();
278 }
279
280 // Finally, strip the GPU modules, as they are no longer needed.
281 for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
282 op.erase();
283 }
284 }
285 };
286
287 } // namespace
288
CreateTFKernelToLLVMPass(StringRef blob_annotation)289 std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass(
290 StringRef blob_annotation) {
291 return std::make_unique<TFKernelToLLVMPass>(blob_annotation);
292 }
293
294 } // namespace transforms
295 } // namespace kernel_gen
296 } // namespace mlir
297