• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "llvm/Support/FormatVariadic.h"
19 #include "mlir/Conversion/LLVMCommon/Pattern.h"  // from @llvm-project
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
28 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
29 
30 namespace mlir {
31 namespace kernel_gen {
32 namespace tf_framework {
33 namespace {
34 
35 using LLVM::LLVMFuncOp;
36 
37 static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
38 static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
39 static constexpr StringRef kCInterfaceReportError =
40     "_mlir_ciface_tf_report_error";
41 static constexpr StringRef kCInterfaceJITCompile =
42     "_mlir_ciface_tf_jit_compile";
43 static constexpr StringRef kCInterfaceJITExecute =
44     "_mlir_ciface_tf_jit_execute";
45 static constexpr StringRef kJITCodeGlobalBaseName = "jit_module_code";
46 static constexpr StringRef kJITArchitectureGlobalBaseName = "jit_architecture";
47 static constexpr StringRef kErrorMessageGlobalBaseName = "error_message";
48 
49 /// Base class for patterns converting TF Framework ops to function calls.
50 template <typename OpTy>
51 class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
52  public:
53   using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
54 
55   // Attempts to find function symbol in the module, adds it if not found.
getOrInsertTFFunction(PatternRewriter & rewriter,Operation * op) const56   FlatSymbolRefAttr getOrInsertTFFunction(PatternRewriter &rewriter,
57                                           Operation *op) const {
58     ModuleOp module = op->getParentOfType<ModuleOp>();
59     StringRef tf_func_name = GetFuncName();
60     auto tf_func = module.lookupSymbol<LLVMFuncOp>(tf_func_name);
61     if (!tf_func) {
62       OpBuilder::InsertionGuard guard(rewriter);
63       rewriter.setInsertionPointToStart(module.getBody());
64       auto func_type = GetFuncType();
65       tf_func = rewriter.create<LLVMFuncOp>(rewriter.getUnknownLoc(),
66                                             tf_func_name, func_type);
67     }
68     return SymbolRefAttr::get(rewriter.getContext(), tf_func_name);
69   }
70 
71  protected:
72   virtual StringRef GetFuncName() const = 0;
73   virtual Type GetFuncType() const = 0;
74 
CreateOrFindGlobalStringConstant(Location loc,OpBuilder & builder,StringRef base_name,StringRef str) const75   Value CreateOrFindGlobalStringConstant(Location loc, OpBuilder &builder,
76                                          StringRef base_name,
77                                          StringRef str) const {
78     auto module =
79         builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
80     std::string global_name =
81         llvm::formatv("{0}_{1}", base_name, llvm::hash_value(str));
82     Operation *global_constant =
83         SymbolTable::lookupNearestSymbolFrom(module, global_name);
84     if (global_constant) {
85       Value global_ptr = builder.create<LLVM::AddressOfOp>(
86           loc, cast<LLVM::GlobalOp>(global_constant));
87       Value c0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
88                                                   builder.getIndexAttr(0));
89       return builder.create<LLVM::GEPOp>(
90           loc, LLVM::LLVMPointerType::get(builder.getIntegerType(8)),
91           global_ptr, ValueRange{c0, c0});
92     }
93     return LLVM::createGlobalString(loc, builder, global_name, str,
94                                     LLVM::Linkage::Internal);
95   }
96 
ConvertArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter,std::function<Value (Attribute)> create_element) const97   std::pair<Value, Value> ConvertArrayAttrToStackAllocatedArray(
98       Location loc, Type size_ty, Type element_ty,
99       llvm::Optional<ArrayAttr> attr, ConversionPatternRewriter *rewriter,
100       std::function<Value(Attribute)> create_element) const {
101     Type element_ptr_ty = LLVM::LLVMPointerType::get(element_ty);
102 
103     // If the attribute is missing or empty, set the element count to 0 and
104     // return NULL.
105     if (!attr.hasValue() || attr.getValue().empty()) {
106       Value zero = rewriter->create<LLVM::ConstantOp>(
107           loc, size_ty, rewriter->getIntegerAttr(size_ty, 0));
108       Value null_ptr = rewriter->create<LLVM::NullOp>(loc, element_ptr_ty);
109       return std::make_pair(zero, null_ptr);
110     }
111 
112     // Allocate array to store the elements.
113     auto &array_attr = attr.getValue();
114     Value array_size = rewriter->create<LLVM::ConstantOp>(
115         loc, size_ty, rewriter->getIntegerAttr(size_ty, array_attr.size()));
116     Value array_ptr = rewriter->create<LLVM::AllocaOp>(
117         loc, element_ptr_ty, array_size, /*alignment=*/0);
118     for (auto &e : llvm::enumerate(array_attr)) {
119       Value index = rewriter->create<LLVM::ConstantOp>(
120           loc, size_ty, rewriter->getIntegerAttr(size_ty, e.index()));
121       Value element_ptr =
122           rewriter->create<LLVM::GEPOp>(loc, element_ptr_ty, array_ptr, index);
123       Value element = create_element(e.value());
124       rewriter->create<LLVM::StoreOp>(loc, element, element_ptr);
125     }
126     return std::make_pair(array_size, array_ptr);
127   }
128 
ConvertIntegerArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const129   std::pair<Value, Value> ConvertIntegerArrayAttrToStackAllocatedArray(
130       Location loc, Type size_ty, Type element_ty,
131       llvm::Optional<ArrayAttr> attr,
132       ConversionPatternRewriter *rewriter) const {
133     assert(size_ty.isa<IntegerType>() && "expect integer size type");
134     assert(element_ty.isa<IntegerType>() && "expect integer element type");
135     return ConvertArrayAttrToStackAllocatedArray(
136         loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) {
137           return rewriter->create<LLVM::ConstantOp>(
138               loc, element_ty,
139               rewriter->getIntegerAttr(element_ty,
140                                        attr.cast<IntegerAttr>().getInt()));
141         });
142   }
143 
ConvertStrArrayAttrToStackAllocatedArray(Location loc,Type size_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const144   std::pair<Value, Value> ConvertStrArrayAttrToStackAllocatedArray(
145       Location loc, Type size_ty, llvm::Optional<ArrayAttr> attr,
146       ConversionPatternRewriter *rewriter) const {
147     assert(size_ty.isa<IntegerType>() && "expect integer size type");
148     Type element_ty = LLVM::LLVMPointerType::get(rewriter->getI8Type());
149     return ConvertArrayAttrToStackAllocatedArray(
150         loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) {
151           std::string zero_terminated =
152               attr.cast<StringAttr>().getValue().str() + '\00';
153           return CreateOrFindGlobalStringConstant(
154               loc, *rewriter, kJITArchitectureGlobalBaseName, zero_terminated);
155         });
156   }
157 };
158 
159 class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
160  public:
161   using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
162 
matchAndRewrite(TFAllocOp tf_alloc_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const163   LogicalResult matchAndRewrite(
164       TFAllocOp tf_alloc_op, ArrayRef<Value> operands,
165       ConversionPatternRewriter &rewriter) const override {
166     mlir::Operation *op = tf_alloc_op.getOperation();
167     Location loc = op->getLoc();
168     TFAllocOp::Adaptor transformed(operands);
169 
170     MemRefType memref_type = tf_alloc_op.getType();
171 
172     // Get memref descriptor sizes.
173     SmallVector<Value, 4> sizes;
174     SmallVector<Value, 4> strides;
175     Value sizeBytes;
176     getMemRefDescriptorSizes(loc, memref_type,
177                              llvm::to_vector<4>(transformed.dyn_sizes()),
178                              rewriter, sizes, strides, sizeBytes);
179     // Get number of elements.
180     Value num_elements = getNumElements(loc, sizes, rewriter);
181     // Get element size.
182     Value element_size =
183         getSizeInBytes(loc, memref_type.getElementType(), rewriter);
184 
185     // Convert `output_index` or set it to -1 if the attribute is missing.
186     Type llvmInt32Type = IntegerType::get(rewriter.getContext(), 32);
187     Value output_index = rewriter.create<LLVM::ConstantOp>(
188         loc, llvmInt32Type,
189         rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
190                                        ? tf_alloc_op.output_index().getValue()
191                                        : -1));
192 
193     // Convert `candidate_input_indices`.
194     auto candidates_count_and_ptr =
195         ConvertIntegerArrayAttrToStackAllocatedArray(
196             loc, rewriter.getI32Type(), rewriter.getI32Type(),
197             tf_alloc_op.input_indices(), &rewriter);
198 
199     // Insert function call.
200     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
201     Value allocated_byte_ptr =
202         rewriter
203             .create<LLVM::CallOp>(
204                 loc, getVoidPtrType(), tf_func_ref,
205                 llvm::makeArrayRef({transformed.ctx(), num_elements,
206                                     element_size, output_index,
207                                     candidates_count_and_ptr.first,
208                                     candidates_count_and_ptr.second}))
209             .getResult(0);
210 
211     MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
212         loc, rewriter, memref_type, allocated_byte_ptr, sizes);
213 
214     // Return the final value of the descriptor.
215     rewriter.replaceOp(op, {memRefDescriptor});
216     return success();
217   }
218 
219  protected:
GetFuncName() const220   StringRef GetFuncName() const override { return kCInterfaceAlloc; }
221 
GetFuncType() const222   Type GetFuncType() const override {
223     Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
224     Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
225     Type llvm_void_ptr_type = getVoidPtrType();
226     return LLVM::LLVMFunctionType::get(
227         llvm_void_ptr_type,
228         llvm::makeArrayRef(
229             {/*void* op_kernel_ctx*/ llvm_void_ptr_type,
230              /*size_t num_elements*/ getIndexType(),
231              /*size_t element_size*/ getIndexType(),
232              /*int32_t output_index*/ llvm_i32_type,
233              /*int32_t num_candidates*/ llvm_i32_type,
234              /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}));
235   }
236 
237  private:
238   // TODO(pifon): Remove strides computation.
CreateMemRefDescriptor(Location loc,ConversionPatternRewriter & rewriter,MemRefType memref_type,Value allocated_byte_ptr,ArrayRef<Value> sizes) const239   MemRefDescriptor CreateMemRefDescriptor(Location loc,
240                                           ConversionPatternRewriter &rewriter,
241                                           MemRefType memref_type,
242                                           Value allocated_byte_ptr,
243                                           ArrayRef<Value> sizes) const {
244     auto memref_desc = MemRefDescriptor::undef(
245         rewriter, loc, typeConverter->convertType(memref_type));
246 
247     // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr.
248     Value allocated_type_ptr = rewriter.create<LLVM::BitcastOp>(
249         loc, getElementPtrType(memref_type), allocated_byte_ptr);
250     memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr);
251     memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr);
252     memref_desc.setConstantOffset(rewriter, loc, 0);
253 
254     if (memref_type.getRank() == 0) {
255       return memref_desc;
256     }
257 
258     // Compute strides and populate descriptor `size` and `stride` fields.
259     Value stride_carried = createIndexConstant(rewriter, loc, 1);
260     for (int pos = sizes.size() - 1; pos >= 0; --pos) {
261       Value size = sizes[pos];
262       memref_desc.setSize(rewriter, loc, pos, size);
263       memref_desc.setStride(rewriter, loc, pos, stride_carried);
264       // Update stride
265       if (pos > 0) {
266         stride_carried =
267             rewriter.create<LLVM::MulOp>(loc, stride_carried, size);
268       }
269     }
270     return memref_desc;
271   }
272 };
273 
274 class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
275  public:
276   using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
277 
matchAndRewrite(TFDeallocOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const278   LogicalResult matchAndRewrite(
279       TFDeallocOp op, ArrayRef<Value> operands,
280       ConversionPatternRewriter &rewriter) const override {
281     TFDeallocOp::Adaptor transformed(operands);
282     MemRefDescriptor memref(transformed.memref());
283 
284     Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
285         op.getLoc(), getVoidPtrType(),
286         memref.allocatedPtr(rewriter, op.getLoc()));
287 
288     // Insert function call.
289     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
290     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
291         op, llvm::None, tf_func_ref,
292         llvm::makeArrayRef({transformed.ctx(), allocated_bytes_ptr}));
293     return success();
294   }
295 
296  protected:
GetFuncName() const297   StringRef GetFuncName() const override { return kCInterfaceDealloc; }
GetFuncType() const298   Type GetFuncType() const override {
299     return LLVM::LLVMFunctionType::get(getVoidType(),
300                                        {getVoidPtrType(), getVoidPtrType()});
301   }
302 };
303 
304 class JITCompileFromStrOpConverter
305     : public ConvertToLLVMCallOpPattern<JITCompileFromStrOp> {
306   using ConvertToLLVMCallOpPattern<
307       JITCompileFromStrOp>::ConvertToLLVMCallOpPattern;
308 
matchAndRewrite(JITCompileFromStrOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const309   LogicalResult matchAndRewrite(
310       JITCompileFromStrOp op, ArrayRef<Value> operands,
311       ConversionPatternRewriter &rewriter) const override {
312     JITCompileFromStrOp::Adaptor transformed(operands);
313     if (transformed.ctx() == nullptr) return failure();
314     auto loc = op.getLoc();
315     std::string zero_terminated_code = op.code().str() + '\00';
316     Value jit_module_code = CreateOrFindGlobalStringConstant(
317         loc, rewriter, kJITCodeGlobalBaseName, zero_terminated_code);
318     std::pair<Value, Value> architectures =
319         ConvertStrArrayAttrToStackAllocatedArray(loc, rewriter.getI64Type(),
320                                                  op.architectures(), &rewriter);
321     std::pair<Value, Value> tile_sizes =
322         ConvertIntegerArrayAttrToStackAllocatedArray(loc, rewriter.getI64Type(),
323                                                      rewriter.getI64Type(),
324                                                      op.tileSizes(), &rewriter);
325     std::pair<Value, Value> unroll_factors =
326         ConvertIntegerArrayAttrToStackAllocatedArray(
327             loc, rewriter.getI64Type(), rewriter.getI64Type(),
328             op.unrollFactors(), &rewriter);
329     Value max_supported_rank = rewriter.create<LLVM::ConstantOp>(
330         loc, rewriter.getI64Type(), op.maxSupportedRankAttr());
331     Value enable_ftz = rewriter.create<LLVM::ConstantOp>(
332         loc, rewriter.getI1Type(), op.enableFtzAttr());
333     Value cpu_codegen = rewriter.create<LLVM::ConstantOp>(
334         loc, rewriter.getI1Type(), op.cpuCodegenAttr());
335     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
336     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
337         op, getVoidPtrType(), tf_func_ref,
338         llvm::makeArrayRef({transformed.ctx(), jit_module_code,
339                             architectures.first, architectures.second,
340                             tile_sizes.first, tile_sizes.second,
341                             unroll_factors.first, unroll_factors.second,
342                             max_supported_rank, enable_ftz, cpu_codegen}));
343     return success();
344   }
345 
346  protected:
GetFuncName() const347   StringRef GetFuncName() const override { return kCInterfaceJITCompile; }
348 
GetFuncType() const349   Type GetFuncType() const override {
350     auto i8_ptr_ty =
351         LLVM::LLVMPointerType::get(IntegerType::get(getContext(), 8));
352     auto i8_ptr_ptr_ty = LLVM::LLVMPointerType::get(i8_ptr_ty);
353     auto i64_ty = IntegerType::get(getContext(), 64);
354     Type i64_ptr_ty = LLVM::LLVMPointerType::get(i64_ty);
355     auto i1_ty = IntegerType::get(getContext(), 1);
356     return LLVM::LLVMFunctionType::get(
357         getVoidPtrType(), {/*void* op_kernel_ctx*/ getVoidPtrType(),
358                            /*char* code*/ i8_ptr_ty,
359                            /*int64_t num_architectures*/ i64_ty,
360                            /*int64_t* architectures_ptr*/ i8_ptr_ptr_ty,
361                            /*int64_t num_tile_sizes*/ i64_ty,
362                            /*int64_t* tile_sizes_ptr*/ i64_ptr_ty,
363                            /*int64_t num_unroll_factors*/ i64_ty,
364                            /*int64_t* unroll_factors_ptr*/ i64_ptr_ty,
365                            /*int64_t max_supported_rank*/ i64_ty,
366                            /*bool enable_ftz*/ i1_ty,
367                            /*bool cpu_codegen*/ i1_ty});
368   }
369 };
370 
371 class JITExecuteOpConverter : public ConvertToLLVMCallOpPattern<JITExecuteOp> {
372  public:
373   using ConvertToLLVMCallOpPattern<JITExecuteOp>::ConvertToLLVMCallOpPattern;
374 
matchAndRewrite(JITExecuteOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const375   LogicalResult matchAndRewrite(
376       JITExecuteOp op, ArrayRef<Value> operands,
377       ConversionPatternRewriter &rewriter) const override {
378     // The TF context must be known for a succesful lowering. Also, we support
379     // only one result.
380     JITExecuteOp::Adaptor transformed(operands, op->getAttrDictionary());
381     if (transformed.ctx() == nullptr || op.operands().empty() ||
382         op.getNumResults() != 1)
383       return failure();
384 
385     // Allocate result on stack.
386     auto loc = op.getLoc();
387     Type result_ty =
388         getTypeConverter()->convertType(op->getResultTypes().front());
389     Type result_ptr_ty = LLVM::LLVMPointerType::get(result_ty);
390     Type i64_ty = rewriter.getI64Type();
391     Value one = rewriter.create<LLVM::ConstantOp>(
392         loc, i64_ty, rewriter.getI64IntegerAttr(1));
393     auto result_ptr =
394         rewriter.create<LLVM::AllocaOp>(loc, result_ptr_ty, one, llvm::None);
395     Type void_ptr_ty = getVoidPtrType();
396     auto result_void_ptr =
397         rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, result_ptr);
398 
399     // Pass the buffer arguments as a stack-allocated array.
400     Type arg_ptr_ty =
401         LLVM::LLVMPointerType::get(transformed.operands().front().getType());
402     Value num_args = rewriter.create<LLVM::ConstantOp>(
403         loc, i64_ty, rewriter.getI64IntegerAttr(transformed.operands().size()));
404     Value args_ptr = rewriter.create<LLVM::AllocaOp>(loc, arg_ptr_ty, num_args,
405                                                      /*alignment=*/0);
406     for (auto it : llvm::enumerate(transformed.operands())) {
407       Value index = rewriter.create<LLVM::ConstantOp>(
408           loc, i64_ty, rewriter.getI64IntegerAttr(it.index()));
409       Value element_ptr =
410           rewriter.create<LLVM::GEPOp>(loc, arg_ptr_ty, args_ptr, index);
411       rewriter.create<LLVM::StoreOp>(loc, it.value(), element_ptr);
412     }
413     auto args_void_ptr =
414         rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, args_ptr);
415 
416     // Materialize runtime call.
417     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
418     rewriter.create<LLVM::CallOp>(
419         loc, llvm::None, tf_func_ref,
420         ValueRange{transformed.ctx(), transformed.callable(), result_void_ptr,
421                    num_args, args_void_ptr});
422 
423     // Copy result (including the descriptor) to a stack-allocated buffer and
424     // free the old descriptor.
425     llvm::SmallVector<Value, 1> final_result = {
426         rewriter.create<LLVM::LoadOp>(loc, result_ptr)};
427     if (failed(copyUnrankedDescriptors(rewriter, loc, op->getResultTypes(),
428                                        final_result,
429                                        /*toDynamic=*/false))) {
430       return failure();
431     }
432 
433     rewriter.replaceOp(op, final_result.front());
434     return success();
435   }
436 
437  protected:
GetFuncName() const438   StringRef GetFuncName() const override { return kCInterfaceJITExecute; }
439 
GetFuncType() const440   Type GetFuncType() const override {
441     auto i64_ty = IntegerType::get(getContext(), 64);
442     auto void_ptr_ty = getVoidPtrType();
443     return LLVM::LLVMFunctionType::get(getVoidType(),
444                                        {/*void* op_kernel_ctx*/ void_ptr_ty,
445                                         /*void* callable*/ void_ptr_ty,
446                                         /*void* result*/ void_ptr_ty,
447                                         /*int64_t num_args*/ i64_ty,
448                                         /*void* args_ptr*/ void_ptr_ty});
449   }
450 };
451 
452 class ReportErrorOpConverter
453     : public ConvertToLLVMCallOpPattern<ReportErrorOp> {
454  public:
455   using ConvertToLLVMCallOpPattern<ReportErrorOp>::ConvertToLLVMCallOpPattern;
456 
matchAndRewrite(ReportErrorOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const457   LogicalResult matchAndRewrite(
458       ReportErrorOp op, ArrayRef<Value> operands,
459       ConversionPatternRewriter &rewriter) const override {
460     ReportErrorOp::Adaptor transformed(operands,
461                                        op.getOperation()->getAttrDictionary());
462 
463     Location loc = op.getLoc();
464     auto module = op->getParentOfType<ModuleOp>();
465     Value message_constant = GenerateErrorMessageConstant(
466         loc, module, transformed.msg().getValue(), rewriter);
467 
468     // Insert function call.
469     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
470     Value error_code = rewriter.create<LLVM::ConstantOp>(
471         loc, typeConverter->convertType(rewriter.getI32Type()),
472         transformed.error_code());
473     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
474         op, llvm::None, tf_func_ref,
475         llvm::makeArrayRef({transformed.ctx(), error_code, message_constant}));
476     return success();
477   }
478 
479  protected:
GetFuncName() const480   StringRef GetFuncName() const override { return kCInterfaceReportError; }
GetFuncType() const481   Type GetFuncType() const override {
482     MLIRContext *ctx = &getTypeConverter()->getContext();
483     auto i8_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
484     auto i32_type = IntegerType::get(ctx, 32);
485     return LLVM::LLVMFunctionType::get(
486         getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type});
487   }
488 
489  private:
490   // Generates an LLVM IR dialect global that contains the name of the given
491   // kernel function as a C string, and returns a pointer to its beginning.
GenerateErrorMessageConstant(Location loc,Operation * module,StringRef message,OpBuilder & builder) const492   Value GenerateErrorMessageConstant(Location loc, Operation *module,
493                                      StringRef message,
494                                      OpBuilder &builder) const {
495     std::string err_str;
496     llvm::raw_string_ostream err_stream(err_str);
497     err_stream << message;
498     if (!loc.isa<UnknownLoc>()) {
499       err_stream << " at ";
500       loc.print(err_stream);
501     }
502     err_stream << '\00';
503     StringRef generated_error(err_stream.str());
504     return CreateOrFindGlobalStringConstant(
505         loc, builder, kErrorMessageGlobalBaseName, generated_error);
506   }
507 };
508 
509 class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
510  public:
511   using ConvertOpToLLVMPattern<NullContextOp>::ConvertOpToLLVMPattern;
512 
matchAndRewrite(NullContextOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const513   LogicalResult matchAndRewrite(
514       NullContextOp op, ArrayRef<Value> operands,
515       ConversionPatternRewriter &rewriter) const override {
516     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
517     return success();
518   }
519 };
520 
521 class NullMemRefOpConverter : public ConvertOpToLLVMPattern<NullMemRefOp> {
522  public:
523   using ConvertOpToLLVMPattern<NullMemRefOp>::ConvertOpToLLVMPattern;
524 
matchAndRewrite(NullMemRefOp null_memref_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const525   LogicalResult matchAndRewrite(
526       NullMemRefOp null_memref_op, ArrayRef<Value> operands,
527       ConversionPatternRewriter &rewriter) const override {
528     Location loc = null_memref_op->getLoc();
529     LLVMTypeConverter type_converter = *getTypeConverter();
530     mlir::Operation *op = null_memref_op.getOperation();
531 
532     auto shaped_result_type = null_memref_op.getType().cast<BaseMemRefType>();
533     unsigned address_space = shaped_result_type.getMemorySpaceAsInt();
534 
535     Type elem_type = shaped_result_type.getElementType();
536     Type llvm_elem_type = type_converter.convertType(elem_type);
537 
538     Value zero = createIndexConstant(rewriter, loc, 0);
539     if (auto result_type = null_memref_op.getType().dyn_cast<MemRefType>()) {
540       // Set all dynamic sizes to 1 and compute fake strides.
541       SmallVector<Value, 4> dyn_sizes(result_type.getNumDynamicDims(),
542                                       createIndexConstant(rewriter, loc, 1));
543       SmallVector<Value, 4> sizes, strides;
544       Value sizeBytes;
545       getMemRefDescriptorSizes(loc, result_type, dyn_sizes, rewriter, sizes,
546                                strides, sizeBytes);
547 
548       // Prepare packed args [allocatedPtr, alignedPtr, offset, sizes, strides]
549       // to create a memref descriptor.
550       Value null = rewriter.create<LLVM::NullOp>(
551           loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
552       SmallVector<Value, 12> packed_values{null, null, zero};
553       packed_values.append(sizes);
554       packed_values.append(strides);
555 
556       rewriter.replaceOp(
557           op, MemRefDescriptor::pack(rewriter, loc, type_converter, result_type,
558                                      packed_values));
559       return success();
560     }
561 
562     auto result_type = null_memref_op.getType().cast<UnrankedMemRefType>();
563     Type llvm_result_type = type_converter.convertType(result_type);
564 
565     auto desc =
566         UnrankedMemRefDescriptor::undef(rewriter, loc, llvm_result_type);
567     desc.setRank(rewriter, loc, zero);
568 
569     // Due to the current way of handling unranked memref results escaping, we
570     // have to actually construct a ranked underlying descriptor instead of just
571     // setting its pointer to NULL.
572     SmallVector<Value, 4> sizes;
573     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
574                                            desc, sizes);
575     Value underlying_desc_ptr = rewriter.create<LLVM::AllocaOp>(
576         loc, getVoidPtrType(), sizes.front(), llvm::None);
577 
578     // Populate underlying ranked descriptor.
579     Type elem_ptr_ptr_type = LLVM::LLVMPointerType::get(
580         LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
581 
582     Value null = rewriter.create<LLVM::NullOp>(
583         loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
584     UnrankedMemRefDescriptor::setAllocatedPtr(
585         rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, null);
586     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
587                                             underlying_desc_ptr,
588                                             elem_ptr_ptr_type, null);
589     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
590                                         underlying_desc_ptr, elem_ptr_ptr_type,
591                                         zero);
592 
593     desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr);
594     rewriter.replaceOp(op, {desc});
595     return success();
596   }
597 };
598 
599 class IsValidMemRefOpConverter
600     : public ConvertOpToLLVMPattern<IsValidMemRefOp> {
601  public:
602   using ConvertOpToLLVMPattern<IsValidMemRefOp>::ConvertOpToLLVMPattern;
603 
matchAndRewrite(IsValidMemRefOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const604   LogicalResult matchAndRewrite(
605       IsValidMemRefOp op, ArrayRef<Value> operands,
606       ConversionPatternRewriter &rewriter) const override {
607     Location loc = op.getLoc();
608     MemRefDescriptor desc(IsValidMemRefOp::Adaptor(operands).arg());
609 
610     // Compare every size in the descriptor to 0 to check num_elements == 0.
611     int64_t rank = op.arg().getType().cast<MemRefType>().getRank();
612     Value is_empty_shape = rewriter.create<LLVM::ConstantOp>(
613         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
614     Value zero = createIndexConstant(rewriter, loc, 0);
615     for (int i = 0; i < rank; ++i) {
616       Value size = desc.size(rewriter, loc, i);
617       Value is_zero_size = rewriter.create<LLVM::ICmpOp>(
618           loc, rewriter.getI1Type(), LLVM::ICmpPredicate::eq, size, zero);
619       is_empty_shape =
620           rewriter.create<LLVM::OrOp>(loc, is_empty_shape, is_zero_size);
621     }
622 
623     Value ptr = rewriter.create<LLVM::BitcastOp>(
624         loc, getVoidPtrType(), desc.allocatedPtr(rewriter, loc));
625     Value null = rewriter.create<LLVM::NullOp>(loc, getVoidPtrType());
626     Value is_not_nullptr = rewriter.create<LLVM::ICmpOp>(
627         loc, rewriter.getI1Type(), LLVM::ICmpPredicate::ne, ptr, null);
628 
629     // Valid memref = ptr != NULL || num_elements == 0;
630     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, is_not_nullptr, is_empty_shape);
631     return success();
632   }
633 };
634 
635 }  // namespace
636 
PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter * converter,RewritePatternSet * patterns)637 void PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter *converter,
638                                                  RewritePatternSet *patterns) {
639   // clang-format off
640   patterns->insert<
641       IsValidMemRefOpConverter,
642       JITCompileFromStrOpConverter,
643       JITExecuteOpConverter,
644       NullContextOpConverter,
645       NullMemRefOpConverter,
646       ReportErrorOpConverter,
647       TFAllocOpConverter,
648       TFDeallocOpConverter>(*converter);
649   // clang-format on
650 }
651 
652 }  // namespace tf_framework
653 }  // namespace kernel_gen
654 }  // namespace mlir
655