• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
17 
18 #include "llvm/Linker/Linker.h"
19 #include "llvm/Transforms/IPO/Internalize.h"
20 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
21 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
22 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassManager.h"  // from @llvm-project
28 #include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
29 #include "mlir/Transforms/Passes.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
31 
32 namespace xla {
33 namespace cpu {
34 namespace {
35 
36 // Lower an MLIR module to an LLVM module.
MakeLLVMModule(mlir::OwningOpRef<mlir::ModuleOp> module,llvm::LLVMContext * context)37 std::unique_ptr<llvm::Module> MakeLLVMModule(
38     mlir::OwningOpRef<mlir::ModuleOp> module, llvm::LLVMContext *context) {
39   // When set, the LLVM backend will be allowed to reassociate floating-point
40   // reductions, which enables much more efficient "horizontal" SIMD
41   // implementations.
42   // TODO(kramerb): link this to the right option, command line flag, etc.
43   constexpr bool kReassociateFPReductions = true;
44 
45   mlir::PassManager manager(module->getContext(),
46                             mlir::OpPassManager::Nesting::Implicit);
47   manager.addPass(mlir::createConvertLinalgToLoopsPass());
48   manager.addPass(mlir::createLowerAffinePass());
49   manager.addPass(mlir::createConvertSCFToCFPass());
50   manager.addPass(mlir::createConvertVectorToLLVMPass(
51       mlir::LowerVectorToLLVMOptions().enableReassociateFPReductions(
52           kReassociateFPReductions)));
53   CHECK(succeeded(manager.run(*module)));
54   return mlir::translateModuleToLLVMIR(*module, *context);
55 }
56 
57 // Get arguments to pass a memref to an mlir function.
BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value * > * args,llvm::IRBuilder<> * b,const Shape & opShape,llvm::Value * op_val)58 void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
59                         llvm::IRBuilder<> *b, const Shape &opShape,
60                         llvm::Value *op_val) {
61   llvm::Type *ty = op_val->getType();
62   if (!ty->isOpaquePointerTy()) {
63     while (auto aty = llvm::dyn_cast<llvm::ArrayType>(
64                ty->getNonOpaquePointerElementType())) {
65       ty = aty->getElementType()->getPointerTo();
66     }
67   }
68   op_val = b->CreateBitCast(op_val, ty);
69 
70   args->push_back(op_val);          // Allocated pointer.
71   args->push_back(op_val);          // Aligned pointer.
72   args->push_back(b->getInt64(0));  // Offset.
73 
74   // Sizes.
75   for (int64_t dim : opShape.dimensions()) {
76     args->push_back(b->getInt64(dim));
77   }
78 
79   int64_t accumulated_stride = 1;
80   llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1);
81   for (int64_t dim : LayoutUtil::MinorToMajor(opShape)) {
82     strides[dim] = accumulated_stride;
83     accumulated_stride *= opShape.dimensions(dim);
84   }
85 
86   // Strides.
87   for (int64_t stride : strides) {
88     args->push_back(b->getInt64(stride));
89   }
90 }
91 }  // namespace
92 
EmitMlirFuncAndCall(mlir::MLIRContext * context,llvm::IRBuilder<> * b,const Shape & result_shape,llvm::ArrayRef<Shape> operand_shapes,llvm::Value * result_ptr,llvm::ArrayRef<llvm::Value * > operand_ptrs,llvm::StringRef func_name,llvm::function_ref<void (mlir::OpBuilder *,mlir::func::FuncOp)> emitter)93 Status EmitMlirFuncAndCall(
94     mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
95     llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
96     llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
97     llvm::function_ref<void(mlir::OpBuilder *, mlir::func::FuncOp)> emitter) {
98   llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
99   mlir::Builder mlir_builder(context);
100 
101   // Get memref types for the inputs and output.
102   TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
103                                                  result_shape, mlir_builder));
104   std::vector<mlir::Type> operand_types = {ret_memref};
105   for (int i = 0; i != operand_shapes.size(); ++i) {
106     TF_ASSIGN_OR_RETURN(
107         mlir::Type op_memref,
108         ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder));
109     operand_types.push_back(op_memref);
110   }
111 
112   // Create the function an call the emission callback.
113   mlir::Location loc = mlir::UnknownLoc::get(context);
114   auto function = mlir::func::FuncOp::create(
115       loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
116   function.addEntryBlock();
117   mlir::OwningOpRef<mlir::ModuleOp> mlir_module = mlir::ModuleOp::create(loc);
118   mlir_module->push_back(function);
119   mlir::OpBuilder op_builder(&function.getBody());
120   emitter(&op_builder, function);
121 
122   // Now link it all into the main LLVM module.
123   auto mlir_llvm_module =
124       MakeLLVMModule(std::move(mlir_module), &b->getContext());
125   mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());
126   llvm::Linker::linkModules(
127       *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
128       [](llvm::Module &M, const llvm::StringSet<> &GVS) {
129         llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
130           return !GV.hasName() || (GVS.count(GV.getName()) == 0);
131         });
132       });
133 
134   // And leave behind a call to the function generated by MLIR.
135   llvm::Function *func = llvm_module->getFunction(func_name);
136   llvm::SmallVector<llvm::Value *, 4> op_vals;
137   BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
138   for (int i = 0; i != operand_shapes.size(); ++i) {
139     BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
140   }
141   b->CreateCall(func, op_vals);
142 
143   return OkStatus();
144 }
145 
146 }  // namespace cpu
147 }  // namespace xla
148