1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
17
18 #include "llvm/Linker/Linker.h"
19 #include "llvm/Transforms/IPO/Internalize.h"
20 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
21 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
22 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24 #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassManager.h" // from @llvm-project
28 #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
29 #include "mlir/Transforms/Passes.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
31
32 namespace xla {
33 namespace cpu {
34 namespace {
35
36 // Lower an MLIR module to an LLVM module.
MakeLLVMModule(mlir::OwningOpRef<mlir::ModuleOp> module,llvm::LLVMContext * context)37 std::unique_ptr<llvm::Module> MakeLLVMModule(
38 mlir::OwningOpRef<mlir::ModuleOp> module, llvm::LLVMContext *context) {
39 // When set, the LLVM backend will be allowed to reassociate floating-point
40 // reductions, which enables much more efficient "horizontal" SIMD
41 // implementations.
42 // TODO(kramerb): link this to the right option, command line flag, etc.
43 constexpr bool kReassociateFPReductions = true;
44
45 mlir::PassManager manager(module->getContext(),
46 mlir::OpPassManager::Nesting::Implicit);
47 manager.addPass(mlir::createConvertLinalgToLoopsPass());
48 manager.addPass(mlir::createLowerAffinePass());
49 manager.addPass(mlir::createConvertSCFToCFPass());
50 manager.addPass(mlir::createConvertVectorToLLVMPass(
51 mlir::LowerVectorToLLVMOptions().enableReassociateFPReductions(
52 kReassociateFPReductions)));
53 CHECK(succeeded(manager.run(*module)));
54 return mlir::translateModuleToLLVMIR(*module, *context);
55 }
56
57 // Get arguments to pass a memref to an mlir function.
BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value * > * args,llvm::IRBuilder<> * b,const Shape & opShape,llvm::Value * op_val)58 void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
59 llvm::IRBuilder<> *b, const Shape &opShape,
60 llvm::Value *op_val) {
61 llvm::Type *ty = op_val->getType();
62 if (!ty->isOpaquePointerTy()) {
63 while (auto aty = llvm::dyn_cast<llvm::ArrayType>(
64 ty->getNonOpaquePointerElementType())) {
65 ty = aty->getElementType()->getPointerTo();
66 }
67 }
68 op_val = b->CreateBitCast(op_val, ty);
69
70 args->push_back(op_val); // Allocated pointer.
71 args->push_back(op_val); // Aligned pointer.
72 args->push_back(b->getInt64(0)); // Offset.
73
74 // Sizes.
75 for (int64_t dim : opShape.dimensions()) {
76 args->push_back(b->getInt64(dim));
77 }
78
79 int64_t accumulated_stride = 1;
80 llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1);
81 for (int64_t dim : LayoutUtil::MinorToMajor(opShape)) {
82 strides[dim] = accumulated_stride;
83 accumulated_stride *= opShape.dimensions(dim);
84 }
85
86 // Strides.
87 for (int64_t stride : strides) {
88 args->push_back(b->getInt64(stride));
89 }
90 }
91 } // namespace
92
EmitMlirFuncAndCall(mlir::MLIRContext * context,llvm::IRBuilder<> * b,const Shape & result_shape,llvm::ArrayRef<Shape> operand_shapes,llvm::Value * result_ptr,llvm::ArrayRef<llvm::Value * > operand_ptrs,llvm::StringRef func_name,llvm::function_ref<void (mlir::OpBuilder *,mlir::func::FuncOp)> emitter)93 Status EmitMlirFuncAndCall(
94 mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
95 llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
96 llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
97 llvm::function_ref<void(mlir::OpBuilder *, mlir::func::FuncOp)> emitter) {
98 llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
99 mlir::Builder mlir_builder(context);
100
101 // Get memref types for the inputs and output.
102 TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
103 result_shape, mlir_builder));
104 std::vector<mlir::Type> operand_types = {ret_memref};
105 for (int i = 0; i != operand_shapes.size(); ++i) {
106 TF_ASSIGN_OR_RETURN(
107 mlir::Type op_memref,
108 ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder));
109 operand_types.push_back(op_memref);
110 }
111
112 // Create the function an call the emission callback.
113 mlir::Location loc = mlir::UnknownLoc::get(context);
114 auto function = mlir::func::FuncOp::create(
115 loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
116 function.addEntryBlock();
117 mlir::OwningOpRef<mlir::ModuleOp> mlir_module = mlir::ModuleOp::create(loc);
118 mlir_module->push_back(function);
119 mlir::OpBuilder op_builder(&function.getBody());
120 emitter(&op_builder, function);
121
122 // Now link it all into the main LLVM module.
123 auto mlir_llvm_module =
124 MakeLLVMModule(std::move(mlir_module), &b->getContext());
125 mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());
126 llvm::Linker::linkModules(
127 *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
128 [](llvm::Module &M, const llvm::StringSet<> &GVS) {
129 llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
130 return !GV.hasName() || (GVS.count(GV.getName()) == 0);
131 });
132 });
133
134 // And leave behind a call to the function generated by MLIR.
135 llvm::Function *func = llvm_module->getFunction(func_name);
136 llvm::SmallVector<llvm::Value *, 4> op_vals;
137 BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
138 for (int i = 0; i != operand_shapes.size(); ++i) {
139 BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
140 }
141 b->CreateCall(func, op_vals);
142
143 return OkStatus();
144 }
145
146 } // namespace cpu
147 } // namespace xla
148