1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdexcept>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project
20 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
21 #include "mlir/Conversion/LLVMCommon/Pattern.h" // from @llvm-project
22 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project
23 #include "mlir/Conversion/MathToLibm/MathToLibm.h" // from @llvm-project
24 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project
25 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
26 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
27 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
28 #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
29 #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
30 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
31 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
32 #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
33 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
34 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" // from @llvm-project
35 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
36 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
38 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
40
41 namespace mlir {
42 namespace kernel_gen {
43 namespace transforms {
44 namespace {
45
// Symbol name of the C-interface wrapper the generated host code calls to
// launch a device kernel. The TensorFlow runtime is expected to provide this
// function at link time (see the tf.LaunchKernel call emitted below).
// NOTE(review): "Libary" is a typo for "Library" in the identifier; kept as-is
// because renaming would touch all uses.
constexpr StringRef kTfWrapperLibaryLaunchHelperName =
    "_mlir_ciface_tf_launch_kernel";

// Pulls in the generated pass base classes (e.g. TFKernelToLLVMPassBase).
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
51
/// A rewrite pattern to convert gpu.launch_func operations into a runtime call
/// for the TensorFlow runtime.
class ConvertLaunchFuncOpToTfRuntimeCallPattern
    : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
 public:
  // `gpu_binary_annotation` names the attribute on the gpu.module that holds
  // the compiled device binary (e.g. a CUBIN or HSACO blob).
  ConvertLaunchFuncOpToTfRuntimeCallPattern(LLVMTypeConverter &type_converter,
                                            StringRef gpu_binary_annotation)
      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
        gpu_binary_annotation_(gpu_binary_annotation) {}

 private:
  // Packs the (promoted) kernel operands into a stack struct and returns a
  // void** array pointing at its fields; see the definition for details.
  Value generateParamsArray(gpu::LaunchFuncOp launch_op,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  // NOTE(review): declared but no definition is visible in this file chunk —
  // confirm it is defined (or remove the declaration) before relying on it.
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult matchAndRewrite(
      gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override;

  MLIRContext *context_ = &this->getTypeConverter()->getContext();

  // Commonly used LLVM dialect types, cached once per pattern instance.
  Type llvm_void_type_ = LLVM::LLVMVoidType::get(context_);
  Type llvm_pointer_type_ =
      LLVM::LLVMPointerType::get(IntegerType::get(context_, 8));  // i8*
  Type llvm_pointer_pointer_type_ =
      LLVM::LLVMPointerType::get(llvm_pointer_type_);  // i8**
  Type llvm_int8_type_ = IntegerType::get(context_, 8);
  Type llvm_int32_type_ = IntegerType::get(context_, 32);
  Type llvm_int64_type_ = IntegerType::get(context_, 64);
  // Pointer-sized integer for launch dimensions; width comes from the
  // type converter's data layout (address space 0).
  Type llvm_intptr_type_ = IntegerType::get(
      context_, this->getTypeConverter()->getPointerBitwidth(0));

  // Name of the gpu.module attribute carrying the device binary blob.
  llvm::SmallString<32> gpu_binary_annotation_;
};
87
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToTfRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    OpBuilder &builder) const {
  auto loc = launch_op.getLoc();
  auto num_kernel_operands = launch_op.getNumKernelOperands();
  // Promote only the trailing kernel operands to the LLVM calling convention
  // (e.g. expanding memref descriptors into their fields); the leading
  // operands of launch_func are the launch configuration.
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launch_op.getOperands().take_back(num_kernel_operands),
      operands.take_back(num_kernel_operands), builder);
  auto num_arguments = arguments.size();
  // Build the struct type whose fields are the promoted argument types.
  SmallVector<Type, 4> argument_types;
  argument_types.reserve(num_arguments);
  for (auto argument : arguments) argument_types.push_back(argument.getType());
  auto struct_type = LLVM::LLVMStructType::getNewIdentified(
      context_, StringRef(), argument_types);
  // alloca one instance of the parameter struct ...
  auto one = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
                                              builder.getI32IntegerAttr(1));
  auto struct_ptr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0);
  // ... and a void*[num_arguments] array of pointers into it.
  auto array_size = builder.create<LLVM::ConstantOp>(
      loc, llvm_int32_type_, builder.getI32IntegerAttr(num_arguments));
  auto array_ptr = builder.create<LLVM::AllocaOp>(
      loc, llvm_pointer_pointer_type_, array_size, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvm_int32_type_,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvm_int32_type_, builder.getI32IntegerAttr(en.index()));
    // GEP to field i of the struct ({0, i}: deref the pointer, pick field i),
    // store the argument value there ...
    auto field_ptr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), field_ptr);
    // ... then store the (type-erased) field address into array slot i.
    auto element_ptr = builder.create<LLVM::GEPOp>(
        loc, llvm_pointer_pointer_type_, array_ptr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvm_pointer_type_, field_ptr);
    builder.create<LLVM::StoreOp>(loc, casted, element_ptr);
  }
  return array_ptr;
}
140
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = <pointer to kernel function name>
// %2 = <see generateParamsArray>
// call %tfLaunchKernel(%ctx, %0, %1, <launch_op operands 0..5>, %2)
LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // Async launches are not supported by the runtime wrapper; bail out so
  // another pattern (or a pass failure) can handle them.
  if (!launch_op.asyncDependencies().empty() || launch_op.asyncToken()) {
    return rewriter.notifyMatchFailure(
        launch_op, "Cannot convert with async dependency or result.");
  }

  Location loc = launch_op.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernel_module = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launch_op, launch_op.getKernelModuleName());
  assert(kernel_module && "expected a kernel module");

  auto binary_attr =
      kernel_module->getAttrOfType<StringAttr>(gpu_binary_annotation_);
  if (!binary_attr) {
    kernel_module.emitOpError()
        << "missing " << gpu_binary_annotation_ << " attribute";
    return failure();
  }

  // Create a global for the module blob.
  SmallString<128> name_buffer(kernel_module.getName());
  name_buffer.append("_blob");
  Value module_blob =
      LLVM::createGlobalString(loc, rewriter, name_buffer.str(),
                               binary_attr.getValue(), LLVM::Linkage::Internal);

  // Make sure the trailing zero is included in the constant, since the
  // runtime consumes the kernel name as a C string.
  auto kernel_name = launch_op.getKernelName();
  SmallString<128> kernel_name_buffer(kernel_name);
  kernel_name_buffer.push_back('\0');

  // Create a global for the kernel name, unique per (module, kernel) pair.
  SmallString<128> kernel_name_global_name_buffer;
  auto kernel_name_global_name =
      (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
          .toStringRef(kernel_name_global_name_buffer);
  auto kernel_name_global =
      LLVM::createGlobalString(loc, rewriter, kernel_name_global_name,
                               kernel_name_buffer, LLVM::Linkage::Internal);

  // Adaptor gives named access to the already-converted operands
  // (grid/block dimensions below).
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());

  // The TensorFlow OpKernelContext is the first argument of the surrounding
  // LLVMFunc.
  Value context_arg =
      launch_op->getParentOfType<LLVM::LLVMFuncOp>().getArgument(0);
  auto kernel_params = generateParamsArray(launch_op, operands, rewriter);

  // Declare the runtime launch helper at module scope if it is not there yet.
  auto function = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
      launch_op, kTfWrapperLibaryLaunchHelperName);
  if (!function) {
    PatternRewriter::InsertionGuard guard(rewriter);
    auto function_type = LLVM::LLVMFunctionType::get(
        llvm_void_type_,
        {
            llvm_pointer_type_,         /* void* context */
            llvm_pointer_type_,         /* void* module_blob */
            llvm_pointer_type_,         /* void* function_name */
            llvm_intptr_type_,          /* intptr_t grid_x_dim */
            llvm_intptr_type_,          /* intptr_t grid_y_dim */
            llvm_intptr_type_,          /* intptr_t grid_z_dim */
            llvm_intptr_type_,          /* intptr_t block_x_dim */
            llvm_intptr_type_,          /* intptr_t block_y_dim */
            llvm_intptr_type_,          /* intptr_t block_z_dim */
            llvm_pointer_pointer_type_, /* void **kernel_params */
        });
    rewriter.setInsertionPointToStart(
        launch_op->getParentOfType<ModuleOp>().getBody());
    function = rewriter.create<LLVM::LLVMFuncOp>(
        loc, kTfWrapperLibaryLaunchHelperName, function_type);
  }
  // Emit the actual runtime call and drop the original launch op.
  rewriter.create<LLVM::CallOp>(
      loc, TypeRange(), rewriter.getSymbolRefAttr(function),
      ArrayRef<Value>{
          context_arg, module_blob, kernel_name_global, adaptor.gridSizeX(),
          adaptor.gridSizeY(), adaptor.gridSizeZ(), adaptor.blockSizeX(),
          adaptor.blockSizeY(), adaptor.blockSizeZ(), kernel_params});

  rewriter.eraseOp(launch_op);
  return success();
}
236
237 class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
getDependentDialects(DialectRegistry & registry) const238 void getDependentDialects(DialectRegistry ®istry) const override {
239 registry.insert<LLVM::LLVMDialect>();
240 }
241
242 public:
TFKernelToLLVMPass(StringRef blob_annotation)243 explicit TFKernelToLLVMPass(StringRef blob_annotation) {
244 if (!blob_annotation.empty()) {
245 blob_annotation_ = blob_annotation.str();
246 }
247 }
248
runOnOperation()249 void runOnOperation() override {
250 ModuleOp m = getOperation();
251
252 // Populate type conversions.
253 MLIRContext *ctx = m.getContext();
254 LLVMTypeConverter type_converter(ctx);
255 type_converter.addConversion([&](tf_framework::OpKernelContextType type) {
256 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
257 });
258 type_converter.addConversion([&](tf_framework::JITCallableType type) {
259 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
260 });
261
262 // Populate patterns.
263 RewritePatternSet patterns(&getContext());
264 populateStdExpandOpsPatterns(patterns);
265 populateMemRefToLLVMConversionPatterns(type_converter, patterns);
266 populateMathToLLVMConversionPatterns(type_converter, patterns);
267 populateStdToLLVMConversionPatterns(type_converter, patterns);
268 populateComplexToLLVMConversionPatterns(type_converter, patterns);
269 populateVectorToLLVMConversionPatterns(type_converter, patterns);
270 populateMathToLibmConversionPatterns(patterns, 0);
271 tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
272 &patterns);
273 patterns.insert<ConvertLaunchFuncOpToTfRuntimeCallPattern>(
274 type_converter, blob_annotation_);
275 // Set target.
276 ConversionTarget target(*ctx);
277 target.addLegalDialect<LLVM::LLVMDialect>();
278 target.addIllegalDialect<StandardOpsDialect, complex::ComplexDialect,
279 gpu::GPUDialect, tf_framework::TFFrameworkDialect,
280 math::MathDialect>();
281 target.addIllegalOp<UnrealizedConversionCastOp>();
282 // Mark modules as legal.
283 target.addLegalOp<ModuleOp, gpu::GPUModuleOp>();
284 // Do not look into gpu modules, only consider host-side.
285 target.markOpRecursivelyLegal<gpu::GPUModuleOp>();
286
287 if (failed(applyFullConversion(m, target, std::move(patterns)))) {
288 signalPassFailure();
289 }
290
291 // Finally, strip the GPU modules, as they are no longer needed.
292 for (auto op : llvm::make_early_inc_range(m.getOps<gpu::GPUModuleOp>())) {
293 op.erase();
294 }
295 }
296 };
297
298 } // namespace
299
CreateTFKernelToLLVMPass(StringRef blob_annotation)300 std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass(
301 StringRef blob_annotation) {
302 return std::make_unique<TFKernelToLLVMPass>(blob_annotation);
303 }
304
305 } // namespace transforms
306 } // namespace kernel_gen
307 } // namespace mlir
308