• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/Support/FormatVariadic.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
26 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
27 
28 namespace mlir {
29 namespace kernel_gen {
30 namespace tf_framework {
31 namespace {
32 
33 using LLVM::LLVMFuncOp;
34 
35 static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
36 static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
37 static constexpr StringRef kCInterfaceReportError =
38     "_mlir_ciface_tf_report_error";
39 
40 /// Base class for patterns converting TF Framework ops to function calls.
41 template <typename OpTy>
42 class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
43  public:
44   using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
45 
46   // Attempts to find function symbol in the module, adds it if not found.
getOrInsertTFFunction(PatternRewriter & rewriter,Operation * op) const47   FlatSymbolRefAttr getOrInsertTFFunction(PatternRewriter &rewriter,
48                                           Operation *op) const {
49     ModuleOp module = op->getParentOfType<ModuleOp>();
50     StringRef tf_func_name = GetFuncName();
51     auto tf_func = module.lookupSymbol<LLVMFuncOp>(tf_func_name);
52     if (!tf_func) {
53       OpBuilder::InsertionGuard guard(rewriter);
54       rewriter.setInsertionPointToStart(module.getBody());
55       auto func_type = GetFuncType();
56       tf_func = rewriter.create<LLVMFuncOp>(rewriter.getUnknownLoc(),
57                                             tf_func_name, func_type);
58     }
59     return SymbolRefAttr::get(rewriter.getContext(), tf_func_name);
60   }
61 
62  protected:
63   virtual StringRef GetFuncName() const = 0;
64   virtual Type GetFuncType() const = 0;
65 };
66 
67 class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
68  public:
69   using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
70 
matchAndRewrite(TFAllocOp tf_alloc_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const71   LogicalResult matchAndRewrite(
72       TFAllocOp tf_alloc_op, ArrayRef<Value> operands,
73       ConversionPatternRewriter &rewriter) const override {
74     mlir::Operation *op = tf_alloc_op.getOperation();
75     Location loc = op->getLoc();
76     TFAllocOp::Adaptor transformed(operands);
77 
78     MemRefType memref_type = tf_alloc_op.getType();
79 
80     // Get memref descriptor sizes.
81     SmallVector<Value, 4> sizes;
82     SmallVector<Value, 4> strides;
83     Value sizeBytes;
84     getMemRefDescriptorSizes(loc, memref_type,
85                              llvm::to_vector<4>(transformed.dyn_sizes()),
86                              rewriter, sizes, strides, sizeBytes);
87     // Get number of elements.
88     Value num_elements = getNumElements(loc, sizes, rewriter);
89     // Get element size.
90     Value element_size =
91         getSizeInBytes(loc, memref_type.getElementType(), rewriter);
92 
93     // Convert `output_index` or set it to -1 if the attribute is missing.
94     Type llvmInt32Type = IntegerType::get(rewriter.getContext(), 32);
95     Value output_index = rewriter.create<LLVM::ConstantOp>(
96         loc, llvmInt32Type,
97         rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
98                                        ? tf_alloc_op.output_index().getValue()
99                                        : -1));
100 
101     // Convert `candidate_input_indices`.
102     auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray(
103         loc, tf_alloc_op.input_indices(), &rewriter);
104 
105     // Insert function call.
106     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
107     Value allocated_byte_ptr =
108         rewriter
109             .create<LLVM::CallOp>(
110                 loc, getVoidPtrType(), tf_func_ref,
111                 llvm::makeArrayRef({transformed.ctx(), num_elements,
112                                     element_size, output_index,
113                                     candidates_count_and_ptr.first,
114                                     candidates_count_and_ptr.second}))
115             .getResult(0);
116 
117     MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
118         loc, rewriter, memref_type, allocated_byte_ptr, sizes);
119 
120     // Return the final value of the descriptor.
121     rewriter.replaceOp(op, {memRefDescriptor});
122     return success();
123   }
124 
125  protected:
GetFuncName() const126   StringRef GetFuncName() const override { return kCInterfaceAlloc; }
127 
GetFuncType() const128   Type GetFuncType() const override {
129     Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
130     Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
131     Type llvm_void_ptr_type = getVoidPtrType();
132     return LLVM::LLVMFunctionType::get(
133         llvm_void_ptr_type,
134         llvm::makeArrayRef(
135             {/*void* op_kernel_ctx*/ llvm_void_ptr_type,
136              /*size_t num_elements*/ getIndexType(),
137              /*size_t element_size*/ getIndexType(),
138              /*int32_t output_index*/ llvm_i32_type,
139              /*int32_t num_candidates*/ llvm_i32_type,
140              /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}));
141   }
142 
143  private:
CreateMemRefDescriptor(Location loc,ConversionPatternRewriter & rewriter,MemRefType memref_type,Value allocated_byte_ptr,ArrayRef<Value> sizes) const144   MemRefDescriptor CreateMemRefDescriptor(Location loc,
145                                           ConversionPatternRewriter &rewriter,
146                                           MemRefType memref_type,
147                                           Value allocated_byte_ptr,
148                                           ArrayRef<Value> sizes) const {
149     auto memref_desc = MemRefDescriptor::undef(
150         rewriter, loc, typeConverter->convertType(memref_type));
151 
152     // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr.
153     Value allocated_type_ptr = rewriter.create<LLVM::BitcastOp>(
154         loc, getElementPtrType(memref_type), allocated_byte_ptr);
155     memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr);
156     memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr);
157     memref_desc.setConstantOffset(rewriter, loc, 0);
158 
159     if (memref_type.getRank() == 0) {
160       return memref_desc;
161     }
162 
163     // Compute strides and populate descriptor `size` and `stride` fields.
164     Value stride_carried = createIndexConstant(rewriter, loc, 1);
165     for (int pos = sizes.size() - 1; pos >= 0; --pos) {
166       Value size = sizes[pos];
167       memref_desc.setSize(rewriter, loc, pos, size);
168       memref_desc.setStride(rewriter, loc, pos, stride_carried);
169       // Update stride
170       if (pos > 0) {
171         stride_carried =
172             rewriter.create<LLVM::MulOp>(loc, stride_carried, size);
173       }
174     }
175     return memref_desc;
176   }
177 
ConvertI32ArrayAttrToStackAllocatedArray(Location loc,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const178   std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray(
179       Location loc, llvm::Optional<ArrayAttr> attr,
180       ConversionPatternRewriter *rewriter) const {
181     Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
182     Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
183 
184     // If the attribute is missing or empty, set the element count to 0 and
185     // return NULL.
186     if (!attr.hasValue() || attr.getValue().empty()) {
187       Value zero = rewriter->create<LLVM::ConstantOp>(
188           loc, llvm_i32_type, rewriter->getI32IntegerAttr(0));
189       Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type);
190       return std::make_pair(zero, null_ptr);
191     }
192 
193     // Allocate array to store the elements.
194     auto &array_attr = attr.getValue();
195     Value array_size = rewriter->create<LLVM::ConstantOp>(
196         loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size()));
197     Value array_ptr = rewriter->create<LLVM::AllocaOp>(
198         loc, llvm_i32_ptr_type, array_size, /*alignment=*/0);
199 
200     for (auto &dim : llvm::enumerate(array_attr)) {
201       Value index = rewriter->create<LLVM::ConstantOp>(
202           loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index()));
203       Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type,
204                                                      array_ptr, index);
205       Value elem = rewriter->create<LLVM::ConstantOp>(
206           loc, llvm_i32_type,
207           rewriter->getI32IntegerAttr(
208               dim.value().cast<IntegerAttr>().getInt()));
209       rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr);
210     }
211     return std::make_pair(array_size, array_ptr);
212   }
213 };
214 
215 class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
216  public:
217   using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
218 
matchAndRewrite(TFDeallocOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const219   LogicalResult matchAndRewrite(
220       TFDeallocOp op, ArrayRef<Value> operands,
221       ConversionPatternRewriter &rewriter) const override {
222     TFDeallocOp::Adaptor transformed(operands);
223     MemRefDescriptor memref(transformed.memref());
224 
225     Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
226         op.getLoc(), getVoidPtrType(),
227         memref.allocatedPtr(rewriter, op.getLoc()));
228 
229     // Insert function call.
230     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
231     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
232         op, llvm::None, tf_func_ref,
233         llvm::makeArrayRef({transformed.ctx(), allocated_bytes_ptr}));
234     return success();
235   }
236 
237  protected:
GetFuncName() const238   StringRef GetFuncName() const override { return kCInterfaceDealloc; }
GetFuncType() const239   Type GetFuncType() const override {
240     return LLVM::LLVMFunctionType::get(getVoidType(),
241                                        {getVoidPtrType(), getVoidPtrType()});
242   }
243 };
244 
245 class ReportErrorOpConverter
246     : public ConvertToLLVMCallOpPattern<ReportErrorOp> {
247  public:
248   using ConvertToLLVMCallOpPattern<ReportErrorOp>::ConvertToLLVMCallOpPattern;
249 
matchAndRewrite(ReportErrorOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const250   LogicalResult matchAndRewrite(
251       ReportErrorOp op, ArrayRef<Value> operands,
252       ConversionPatternRewriter &rewriter) const override {
253     ReportErrorOp::Adaptor transformed(operands,
254                                        op.getOperation()->getAttrDictionary());
255 
256     Location loc = op.getLoc();
257     auto module = op->getParentOfType<ModuleOp>();
258     Value message_constant = GenerateErrorMessageConstant(
259         loc, module, transformed.msg().getValue(), rewriter);
260 
261     // Insert function call.
262     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
263     Value error_code = rewriter.create<LLVM::ConstantOp>(
264         loc, typeConverter->convertType(rewriter.getI32Type()),
265         transformed.error_code());
266     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
267 
268         op, llvm::None, tf_func_ref,
269         llvm::makeArrayRef({transformed.ctx(), error_code, message_constant}));
270     return success();
271   }
272 
273  protected:
GetFuncName() const274   StringRef GetFuncName() const override { return kCInterfaceReportError; }
GetFuncType() const275   Type GetFuncType() const override {
276     MLIRContext *ctx = &getTypeConverter()->getContext();
277     auto i8_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
278     auto i32_type = IntegerType::get(ctx, 32);
279     return LLVM::LLVMFunctionType::get(
280         getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type});
281   }
282 
283  private:
284   // Generates an LLVM IR dialect global that contains the name of the given
285   // kernel function as a C string, and returns a pointer to its beginning.
GenerateErrorMessageConstant(Location loc,Operation * module,StringRef message,OpBuilder & builder) const286   Value GenerateErrorMessageConstant(Location loc, Operation *module,
287                                      StringRef message,
288                                      OpBuilder &builder) const {
289     std::string loc_str;
290     llvm::raw_string_ostream loc_stream(loc_str);
291     loc_stream << message << " at ";
292     loc.print(loc_stream);
293 
294     StringRef generated_error(loc_stream.str().c_str());
295 
296     std::string global_name =
297         llvm::formatv("error_message_{0}", llvm::hash_value(generated_error));
298 
299     Operation *global_constant =
300         SymbolTable::lookupNearestSymbolFrom(module, global_name);
301 
302     if (global_constant) {
303       Value globalPtr = builder.create<LLVM::AddressOfOp>(
304           loc, cast<LLVM::GlobalOp>(global_constant));
305 
306       MLIRContext *ctx = &getTypeConverter()->getContext();
307       Value c0 = builder.create<LLVM::ConstantOp>(
308           loc, IntegerType::get(ctx, 64),
309           builder.getIntegerAttr(builder.getIndexType(), 0));
310       return builder.create<LLVM::GEPOp>(
311           loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
312           ValueRange{c0, c0});
313     }
314     return LLVM::createGlobalString(loc, builder, global_name, generated_error,
315                                     LLVM::Linkage::Internal);
316   }
317 };
318 
319 class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
320  public:
321   using ConvertOpToLLVMPattern<NullContextOp>::ConvertOpToLLVMPattern;
322 
matchAndRewrite(NullContextOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const323   LogicalResult matchAndRewrite(
324       NullContextOp op, ArrayRef<Value> operands,
325       ConversionPatternRewriter &rewriter) const override {
326     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
327     return success();
328   }
329 };
330 
331 class NullMemRefOpConverter : public ConvertOpToLLVMPattern<NullMemRefOp> {
332  public:
333   using ConvertOpToLLVMPattern<NullMemRefOp>::ConvertOpToLLVMPattern;
334 
matchAndRewrite(NullMemRefOp null_memref_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const335   LogicalResult matchAndRewrite(
336       NullMemRefOp null_memref_op, ArrayRef<Value> operands,
337       ConversionPatternRewriter &rewriter) const override {
338     mlir::Operation *op = null_memref_op.getOperation();
339 
340     Location loc = op->getLoc();
341     auto result_type = null_memref_op.getType().cast<UnrankedMemRefType>();
342     Type llvm_result_type = typeConverter->convertType(result_type);
343 
344     auto desc =
345         UnrankedMemRefDescriptor::undef(rewriter, loc, llvm_result_type);
346     Value zero = createIndexConstant(rewriter, loc, 0);
347     desc.setRank(rewriter, loc, zero);
348 
349     // Due to the current way of handling unranked memref results escaping, we
350     // have to actually construct a ranked underlying descriptor instead of just
351     // setting its pointer to NULL.
352     SmallVector<Value, 4> sizes;
353     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
354                                            desc, sizes);
355     Value underlying_desc_ptr = rewriter.create<LLVM::AllocaOp>(
356         loc, getVoidPtrType(), sizes.front(), llvm::None);
357 
358     // Populate underlying ranked descriptor.
359     unsigned address_space = result_type.getMemorySpace();
360     Type elem_type = result_type.getElementType();
361     Type llvm_elem_type = typeConverter->convertType(elem_type);
362     Type elem_ptr_ptr_type = LLVM::LLVMPointerType::get(
363         LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
364 
365     auto nullPtr = rewriter.create<LLVM::NullOp>(
366         loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
367     UnrankedMemRefDescriptor::setAllocatedPtr(
368         rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, nullPtr);
369     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
370                                             underlying_desc_ptr,
371                                             elem_ptr_ptr_type, nullPtr);
372     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
373                                         underlying_desc_ptr, elem_ptr_ptr_type,
374                                         zero);
375 
376     desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr);
377     rewriter.replaceOp(op, {desc});
378     return success();
379   }
380 };
381 
382 }  // namespace
383 
PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter * converter,OwningRewritePatternList * patterns)384 void PopulateTFFrameworkToLLVMConversionPatterns(
385     LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
386   // clang-format off
387   patterns->insert<
388       NullContextOpConverter,
389       NullMemRefOpConverter,
390       ReportErrorOpConverter,
391       TFAllocOpConverter,
392       TFDeallocOpConverter
393     >(*converter);
394   // clang-format on
395 }
396 
397 }  // namespace tf_framework
398 }  // namespace kernel_gen
399 }  // namespace mlir
400