1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "llvm/Support/FormatVariadic.h"
19 #include "mlir/Conversion/LLVMCommon/Pattern.h" // from @llvm-project
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
21 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
28 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
29
30 namespace mlir {
31 namespace kernel_gen {
32 namespace tf_framework {
33 namespace {
34
35 using LLVM::LLVMFuncOp;
36
37 static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
38 static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
39 static constexpr StringRef kCInterfaceReportError =
40 "_mlir_ciface_tf_report_error";
41 static constexpr StringRef kCInterfaceJITCompile =
42 "_mlir_ciface_tf_jit_compile";
43 static constexpr StringRef kCInterfaceJITExecute =
44 "_mlir_ciface_tf_jit_execute";
45 static constexpr StringRef kJITCodeGlobalBaseName = "jit_module_code";
46 static constexpr StringRef kJITArchitectureGlobalBaseName = "jit_architecture";
47 static constexpr StringRef kErrorMessageGlobalBaseName = "error_message";
48
49 /// Base class for patterns converting TF Framework ops to function calls.
50 template <typename OpTy>
51 class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
52 public:
53 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
54
55 // Attempts to find function symbol in the module, adds it if not found.
getOrInsertTFFunction(PatternRewriter & rewriter,Operation * op) const56 FlatSymbolRefAttr getOrInsertTFFunction(PatternRewriter &rewriter,
57 Operation *op) const {
58 ModuleOp module = op->getParentOfType<ModuleOp>();
59 StringRef tf_func_name = GetFuncName();
60 auto tf_func = module.lookupSymbol<LLVMFuncOp>(tf_func_name);
61 if (!tf_func) {
62 OpBuilder::InsertionGuard guard(rewriter);
63 rewriter.setInsertionPointToStart(module.getBody());
64 auto func_type = GetFuncType();
65 tf_func = rewriter.create<LLVMFuncOp>(rewriter.getUnknownLoc(),
66 tf_func_name, func_type);
67 }
68 return SymbolRefAttr::get(rewriter.getContext(), tf_func_name);
69 }
70
71 protected:
72 virtual StringRef GetFuncName() const = 0;
73 virtual Type GetFuncType() const = 0;
74
CreateOrFindGlobalStringConstant(Location loc,OpBuilder & builder,StringRef base_name,StringRef str) const75 Value CreateOrFindGlobalStringConstant(Location loc, OpBuilder &builder,
76 StringRef base_name,
77 StringRef str) const {
78 auto module =
79 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
80 std::string global_name =
81 llvm::formatv("{0}_{1}", base_name, llvm::hash_value(str));
82 Operation *global_constant =
83 SymbolTable::lookupNearestSymbolFrom(module, global_name);
84 if (global_constant) {
85 Value global_ptr = builder.create<LLVM::AddressOfOp>(
86 loc, cast<LLVM::GlobalOp>(global_constant));
87 Value c0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
88 builder.getIndexAttr(0));
89 return builder.create<LLVM::GEPOp>(
90 loc, LLVM::LLVMPointerType::get(builder.getIntegerType(8)),
91 global_ptr, ValueRange{c0, c0});
92 }
93 return LLVM::createGlobalString(loc, builder, global_name, str,
94 LLVM::Linkage::Internal);
95 }
96
ConvertArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter,std::function<Value (Attribute)> create_element) const97 std::pair<Value, Value> ConvertArrayAttrToStackAllocatedArray(
98 Location loc, Type size_ty, Type element_ty,
99 llvm::Optional<ArrayAttr> attr, ConversionPatternRewriter *rewriter,
100 std::function<Value(Attribute)> create_element) const {
101 Type element_ptr_ty = LLVM::LLVMPointerType::get(element_ty);
102
103 // If the attribute is missing or empty, set the element count to 0 and
104 // return NULL.
105 if (!attr.hasValue() || attr.getValue().empty()) {
106 Value zero = rewriter->create<LLVM::ConstantOp>(
107 loc, size_ty, rewriter->getIntegerAttr(size_ty, 0));
108 Value null_ptr = rewriter->create<LLVM::NullOp>(loc, element_ptr_ty);
109 return std::make_pair(zero, null_ptr);
110 }
111
112 // Allocate array to store the elements.
113 auto &array_attr = attr.getValue();
114 Value array_size = rewriter->create<LLVM::ConstantOp>(
115 loc, size_ty, rewriter->getIntegerAttr(size_ty, array_attr.size()));
116 Value array_ptr = rewriter->create<LLVM::AllocaOp>(
117 loc, element_ptr_ty, array_size, /*alignment=*/0);
118 for (auto &e : llvm::enumerate(array_attr)) {
119 Value index = rewriter->create<LLVM::ConstantOp>(
120 loc, size_ty, rewriter->getIntegerAttr(size_ty, e.index()));
121 Value element_ptr =
122 rewriter->create<LLVM::GEPOp>(loc, element_ptr_ty, array_ptr, index);
123 Value element = create_element(e.value());
124 rewriter->create<LLVM::StoreOp>(loc, element, element_ptr);
125 }
126 return std::make_pair(array_size, array_ptr);
127 }
128
ConvertIntegerArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const129 std::pair<Value, Value> ConvertIntegerArrayAttrToStackAllocatedArray(
130 Location loc, Type size_ty, Type element_ty,
131 llvm::Optional<ArrayAttr> attr,
132 ConversionPatternRewriter *rewriter) const {
133 assert(size_ty.isa<IntegerType>() && "expect integer size type");
134 assert(element_ty.isa<IntegerType>() && "expect integer element type");
135 return ConvertArrayAttrToStackAllocatedArray(
136 loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) {
137 return rewriter->create<LLVM::ConstantOp>(
138 loc, element_ty,
139 rewriter->getIntegerAttr(element_ty,
140 attr.cast<IntegerAttr>().getInt()));
141 });
142 }
143
ConvertStrArrayAttrToStackAllocatedArray(Location loc,Type size_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const144 std::pair<Value, Value> ConvertStrArrayAttrToStackAllocatedArray(
145 Location loc, Type size_ty, llvm::Optional<ArrayAttr> attr,
146 ConversionPatternRewriter *rewriter) const {
147 assert(size_ty.isa<IntegerType>() && "expect integer size type");
148 Type element_ty = LLVM::LLVMPointerType::get(rewriter->getI8Type());
149 return ConvertArrayAttrToStackAllocatedArray(
150 loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) {
151 std::string zero_terminated =
152 attr.cast<StringAttr>().getValue().str() + '\00';
153 return CreateOrFindGlobalStringConstant(
154 loc, *rewriter, kJITArchitectureGlobalBaseName, zero_terminated);
155 });
156 }
157 };
158
159 class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
160 public:
161 using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
162
matchAndRewrite(TFAllocOp tf_alloc_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const163 LogicalResult matchAndRewrite(
164 TFAllocOp tf_alloc_op, ArrayRef<Value> operands,
165 ConversionPatternRewriter &rewriter) const override {
166 mlir::Operation *op = tf_alloc_op.getOperation();
167 Location loc = op->getLoc();
168 TFAllocOp::Adaptor transformed(operands);
169
170 MemRefType memref_type = tf_alloc_op.getType();
171
172 // Get memref descriptor sizes.
173 SmallVector<Value, 4> sizes;
174 SmallVector<Value, 4> strides;
175 Value sizeBytes;
176 getMemRefDescriptorSizes(loc, memref_type,
177 llvm::to_vector<4>(transformed.dyn_sizes()),
178 rewriter, sizes, strides, sizeBytes);
179 // Get number of elements.
180 Value num_elements = getNumElements(loc, sizes, rewriter);
181 // Get element size.
182 Value element_size =
183 getSizeInBytes(loc, memref_type.getElementType(), rewriter);
184
185 // Convert `output_index` or set it to -1 if the attribute is missing.
186 Type llvmInt32Type = IntegerType::get(rewriter.getContext(), 32);
187 Value output_index = rewriter.create<LLVM::ConstantOp>(
188 loc, llvmInt32Type,
189 rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
190 ? tf_alloc_op.output_index().getValue()
191 : -1));
192
193 // Convert `candidate_input_indices`.
194 auto candidates_count_and_ptr =
195 ConvertIntegerArrayAttrToStackAllocatedArray(
196 loc, rewriter.getI32Type(), rewriter.getI32Type(),
197 tf_alloc_op.input_indices(), &rewriter);
198
199 // Insert function call.
200 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
201 Value allocated_byte_ptr =
202 rewriter
203 .create<LLVM::CallOp>(
204 loc, getVoidPtrType(), tf_func_ref,
205 llvm::makeArrayRef({transformed.ctx(), num_elements,
206 element_size, output_index,
207 candidates_count_and_ptr.first,
208 candidates_count_and_ptr.second}))
209 .getResult(0);
210
211 MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
212 loc, rewriter, memref_type, allocated_byte_ptr, sizes);
213
214 // Return the final value of the descriptor.
215 rewriter.replaceOp(op, {memRefDescriptor});
216 return success();
217 }
218
219 protected:
GetFuncName() const220 StringRef GetFuncName() const override { return kCInterfaceAlloc; }
221
GetFuncType() const222 Type GetFuncType() const override {
223 Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
224 Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
225 Type llvm_void_ptr_type = getVoidPtrType();
226 return LLVM::LLVMFunctionType::get(
227 llvm_void_ptr_type,
228 llvm::makeArrayRef(
229 {/*void* op_kernel_ctx*/ llvm_void_ptr_type,
230 /*size_t num_elements*/ getIndexType(),
231 /*size_t element_size*/ getIndexType(),
232 /*int32_t output_index*/ llvm_i32_type,
233 /*int32_t num_candidates*/ llvm_i32_type,
234 /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}));
235 }
236
237 private:
238 // TODO(pifon): Remove strides computation.
CreateMemRefDescriptor(Location loc,ConversionPatternRewriter & rewriter,MemRefType memref_type,Value allocated_byte_ptr,ArrayRef<Value> sizes) const239 MemRefDescriptor CreateMemRefDescriptor(Location loc,
240 ConversionPatternRewriter &rewriter,
241 MemRefType memref_type,
242 Value allocated_byte_ptr,
243 ArrayRef<Value> sizes) const {
244 auto memref_desc = MemRefDescriptor::undef(
245 rewriter, loc, typeConverter->convertType(memref_type));
246
247 // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr.
248 Value allocated_type_ptr = rewriter.create<LLVM::BitcastOp>(
249 loc, getElementPtrType(memref_type), allocated_byte_ptr);
250 memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr);
251 memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr);
252 memref_desc.setConstantOffset(rewriter, loc, 0);
253
254 if (memref_type.getRank() == 0) {
255 return memref_desc;
256 }
257
258 // Compute strides and populate descriptor `size` and `stride` fields.
259 Value stride_carried = createIndexConstant(rewriter, loc, 1);
260 for (int pos = sizes.size() - 1; pos >= 0; --pos) {
261 Value size = sizes[pos];
262 memref_desc.setSize(rewriter, loc, pos, size);
263 memref_desc.setStride(rewriter, loc, pos, stride_carried);
264 // Update stride
265 if (pos > 0) {
266 stride_carried =
267 rewriter.create<LLVM::MulOp>(loc, stride_carried, size);
268 }
269 }
270 return memref_desc;
271 }
272 };
273
274 class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
275 public:
276 using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
277
matchAndRewrite(TFDeallocOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const278 LogicalResult matchAndRewrite(
279 TFDeallocOp op, ArrayRef<Value> operands,
280 ConversionPatternRewriter &rewriter) const override {
281 TFDeallocOp::Adaptor transformed(operands);
282 MemRefDescriptor memref(transformed.memref());
283
284 Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
285 op.getLoc(), getVoidPtrType(),
286 memref.allocatedPtr(rewriter, op.getLoc()));
287
288 // Insert function call.
289 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
290 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
291 op, llvm::None, tf_func_ref,
292 llvm::makeArrayRef({transformed.ctx(), allocated_bytes_ptr}));
293 return success();
294 }
295
296 protected:
GetFuncName() const297 StringRef GetFuncName() const override { return kCInterfaceDealloc; }
GetFuncType() const298 Type GetFuncType() const override {
299 return LLVM::LLVMFunctionType::get(getVoidType(),
300 {getVoidPtrType(), getVoidPtrType()});
301 }
302 };
303
304 class JITCompileFromStrOpConverter
305 : public ConvertToLLVMCallOpPattern<JITCompileFromStrOp> {
306 using ConvertToLLVMCallOpPattern<
307 JITCompileFromStrOp>::ConvertToLLVMCallOpPattern;
308
matchAndRewrite(JITCompileFromStrOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const309 LogicalResult matchAndRewrite(
310 JITCompileFromStrOp op, ArrayRef<Value> operands,
311 ConversionPatternRewriter &rewriter) const override {
312 JITCompileFromStrOp::Adaptor transformed(operands);
313 if (transformed.ctx() == nullptr) return failure();
314 auto loc = op.getLoc();
315 std::string zero_terminated_code = op.code().str() + '\00';
316 Value jit_module_code = CreateOrFindGlobalStringConstant(
317 loc, rewriter, kJITCodeGlobalBaseName, zero_terminated_code);
318 std::pair<Value, Value> architectures =
319 ConvertStrArrayAttrToStackAllocatedArray(loc, rewriter.getI64Type(),
320 op.architectures(), &rewriter);
321 std::pair<Value, Value> tile_sizes =
322 ConvertIntegerArrayAttrToStackAllocatedArray(loc, rewriter.getI64Type(),
323 rewriter.getI64Type(),
324 op.tileSizes(), &rewriter);
325 std::pair<Value, Value> unroll_factors =
326 ConvertIntegerArrayAttrToStackAllocatedArray(
327 loc, rewriter.getI64Type(), rewriter.getI64Type(),
328 op.unrollFactors(), &rewriter);
329 Value max_supported_rank = rewriter.create<LLVM::ConstantOp>(
330 loc, rewriter.getI64Type(), op.maxSupportedRankAttr());
331 Value enable_ftz = rewriter.create<LLVM::ConstantOp>(
332 loc, rewriter.getI1Type(), op.enableFtzAttr());
333 Value cpu_codegen = rewriter.create<LLVM::ConstantOp>(
334 loc, rewriter.getI1Type(), op.cpuCodegenAttr());
335 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
336 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
337 op, getVoidPtrType(), tf_func_ref,
338 llvm::makeArrayRef({transformed.ctx(), jit_module_code,
339 architectures.first, architectures.second,
340 tile_sizes.first, tile_sizes.second,
341 unroll_factors.first, unroll_factors.second,
342 max_supported_rank, enable_ftz, cpu_codegen}));
343 return success();
344 }
345
346 protected:
GetFuncName() const347 StringRef GetFuncName() const override { return kCInterfaceJITCompile; }
348
GetFuncType() const349 Type GetFuncType() const override {
350 auto i8_ptr_ty =
351 LLVM::LLVMPointerType::get(IntegerType::get(getContext(), 8));
352 auto i8_ptr_ptr_ty = LLVM::LLVMPointerType::get(i8_ptr_ty);
353 auto i64_ty = IntegerType::get(getContext(), 64);
354 Type i64_ptr_ty = LLVM::LLVMPointerType::get(i64_ty);
355 auto i1_ty = IntegerType::get(getContext(), 1);
356 return LLVM::LLVMFunctionType::get(
357 getVoidPtrType(), {/*void* op_kernel_ctx*/ getVoidPtrType(),
358 /*char* code*/ i8_ptr_ty,
359 /*int64_t num_architectures*/ i64_ty,
360 /*int64_t* architectures_ptr*/ i8_ptr_ptr_ty,
361 /*int64_t num_tile_sizes*/ i64_ty,
362 /*int64_t* tile_sizes_ptr*/ i64_ptr_ty,
363 /*int64_t num_unroll_factors*/ i64_ty,
364 /*int64_t* unroll_factors_ptr*/ i64_ptr_ty,
365 /*int64_t max_supported_rank*/ i64_ty,
366 /*bool enable_ftz*/ i1_ty,
367 /*bool cpu_codegen*/ i1_ty});
368 }
369 };
370
371 class JITExecuteOpConverter : public ConvertToLLVMCallOpPattern<JITExecuteOp> {
372 public:
373 using ConvertToLLVMCallOpPattern<JITExecuteOp>::ConvertToLLVMCallOpPattern;
374
matchAndRewrite(JITExecuteOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const375 LogicalResult matchAndRewrite(
376 JITExecuteOp op, ArrayRef<Value> operands,
377 ConversionPatternRewriter &rewriter) const override {
378 // The TF context must be known for a succesful lowering. Also, we support
379 // only one result.
380 JITExecuteOp::Adaptor transformed(operands, op->getAttrDictionary());
381 if (transformed.ctx() == nullptr || op.operands().empty() ||
382 op.getNumResults() != 1)
383 return failure();
384
385 // Allocate result on stack.
386 auto loc = op.getLoc();
387 Type result_ty =
388 getTypeConverter()->convertType(op->getResultTypes().front());
389 Type result_ptr_ty = LLVM::LLVMPointerType::get(result_ty);
390 Type i64_ty = rewriter.getI64Type();
391 Value one = rewriter.create<LLVM::ConstantOp>(
392 loc, i64_ty, rewriter.getI64IntegerAttr(1));
393 auto result_ptr =
394 rewriter.create<LLVM::AllocaOp>(loc, result_ptr_ty, one, llvm::None);
395 Type void_ptr_ty = getVoidPtrType();
396 auto result_void_ptr =
397 rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, result_ptr);
398
399 // Pass the buffer arguments as a stack-allocated array.
400 Type arg_ptr_ty =
401 LLVM::LLVMPointerType::get(transformed.operands().front().getType());
402 Value num_args = rewriter.create<LLVM::ConstantOp>(
403 loc, i64_ty, rewriter.getI64IntegerAttr(transformed.operands().size()));
404 Value args_ptr = rewriter.create<LLVM::AllocaOp>(loc, arg_ptr_ty, num_args,
405 /*alignment=*/0);
406 for (auto it : llvm::enumerate(transformed.operands())) {
407 Value index = rewriter.create<LLVM::ConstantOp>(
408 loc, i64_ty, rewriter.getI64IntegerAttr(it.index()));
409 Value element_ptr =
410 rewriter.create<LLVM::GEPOp>(loc, arg_ptr_ty, args_ptr, index);
411 rewriter.create<LLVM::StoreOp>(loc, it.value(), element_ptr);
412 }
413 auto args_void_ptr =
414 rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, args_ptr);
415
416 // Materialize runtime call.
417 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
418 rewriter.create<LLVM::CallOp>(
419 loc, llvm::None, tf_func_ref,
420 ValueRange{transformed.ctx(), transformed.callable(), result_void_ptr,
421 num_args, args_void_ptr});
422
423 // Copy result (including the descriptor) to a stack-allocated buffer and
424 // free the old descriptor.
425 llvm::SmallVector<Value, 1> final_result = {
426 rewriter.create<LLVM::LoadOp>(loc, result_ptr)};
427 if (failed(copyUnrankedDescriptors(rewriter, loc, op->getResultTypes(),
428 final_result,
429 /*toDynamic=*/false))) {
430 return failure();
431 }
432
433 rewriter.replaceOp(op, final_result.front());
434 return success();
435 }
436
437 protected:
GetFuncName() const438 StringRef GetFuncName() const override { return kCInterfaceJITExecute; }
439
GetFuncType() const440 Type GetFuncType() const override {
441 auto i64_ty = IntegerType::get(getContext(), 64);
442 auto void_ptr_ty = getVoidPtrType();
443 return LLVM::LLVMFunctionType::get(getVoidType(),
444 {/*void* op_kernel_ctx*/ void_ptr_ty,
445 /*void* callable*/ void_ptr_ty,
446 /*void* result*/ void_ptr_ty,
447 /*int64_t num_args*/ i64_ty,
448 /*void* args_ptr*/ void_ptr_ty});
449 }
450 };
451
452 class ReportErrorOpConverter
453 : public ConvertToLLVMCallOpPattern<ReportErrorOp> {
454 public:
455 using ConvertToLLVMCallOpPattern<ReportErrorOp>::ConvertToLLVMCallOpPattern;
456
matchAndRewrite(ReportErrorOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const457 LogicalResult matchAndRewrite(
458 ReportErrorOp op, ArrayRef<Value> operands,
459 ConversionPatternRewriter &rewriter) const override {
460 ReportErrorOp::Adaptor transformed(operands,
461 op.getOperation()->getAttrDictionary());
462
463 Location loc = op.getLoc();
464 auto module = op->getParentOfType<ModuleOp>();
465 Value message_constant = GenerateErrorMessageConstant(
466 loc, module, transformed.msg().getValue(), rewriter);
467
468 // Insert function call.
469 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
470 Value error_code = rewriter.create<LLVM::ConstantOp>(
471 loc, typeConverter->convertType(rewriter.getI32Type()),
472 transformed.error_code());
473 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
474 op, llvm::None, tf_func_ref,
475 llvm::makeArrayRef({transformed.ctx(), error_code, message_constant}));
476 return success();
477 }
478
479 protected:
GetFuncName() const480 StringRef GetFuncName() const override { return kCInterfaceReportError; }
GetFuncType() const481 Type GetFuncType() const override {
482 MLIRContext *ctx = &getTypeConverter()->getContext();
483 auto i8_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
484 auto i32_type = IntegerType::get(ctx, 32);
485 return LLVM::LLVMFunctionType::get(
486 getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type});
487 }
488
489 private:
490 // Generates an LLVM IR dialect global that contains the name of the given
491 // kernel function as a C string, and returns a pointer to its beginning.
GenerateErrorMessageConstant(Location loc,Operation * module,StringRef message,OpBuilder & builder) const492 Value GenerateErrorMessageConstant(Location loc, Operation *module,
493 StringRef message,
494 OpBuilder &builder) const {
495 std::string err_str;
496 llvm::raw_string_ostream err_stream(err_str);
497 err_stream << message;
498 if (!loc.isa<UnknownLoc>()) {
499 err_stream << " at ";
500 loc.print(err_stream);
501 }
502 err_stream << '\00';
503 StringRef generated_error(err_stream.str());
504 return CreateOrFindGlobalStringConstant(
505 loc, builder, kErrorMessageGlobalBaseName, generated_error);
506 }
507 };
508
509 class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
510 public:
511 using ConvertOpToLLVMPattern<NullContextOp>::ConvertOpToLLVMPattern;
512
matchAndRewrite(NullContextOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const513 LogicalResult matchAndRewrite(
514 NullContextOp op, ArrayRef<Value> operands,
515 ConversionPatternRewriter &rewriter) const override {
516 rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
517 return success();
518 }
519 };
520
521 class NullMemRefOpConverter : public ConvertOpToLLVMPattern<NullMemRefOp> {
522 public:
523 using ConvertOpToLLVMPattern<NullMemRefOp>::ConvertOpToLLVMPattern;
524
matchAndRewrite(NullMemRefOp null_memref_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const525 LogicalResult matchAndRewrite(
526 NullMemRefOp null_memref_op, ArrayRef<Value> operands,
527 ConversionPatternRewriter &rewriter) const override {
528 Location loc = null_memref_op->getLoc();
529 LLVMTypeConverter type_converter = *getTypeConverter();
530 mlir::Operation *op = null_memref_op.getOperation();
531
532 auto shaped_result_type = null_memref_op.getType().cast<BaseMemRefType>();
533 unsigned address_space = shaped_result_type.getMemorySpaceAsInt();
534
535 Type elem_type = shaped_result_type.getElementType();
536 Type llvm_elem_type = type_converter.convertType(elem_type);
537
538 Value zero = createIndexConstant(rewriter, loc, 0);
539 if (auto result_type = null_memref_op.getType().dyn_cast<MemRefType>()) {
540 // Set all dynamic sizes to 1 and compute fake strides.
541 SmallVector<Value, 4> dyn_sizes(result_type.getNumDynamicDims(),
542 createIndexConstant(rewriter, loc, 1));
543 SmallVector<Value, 4> sizes, strides;
544 Value sizeBytes;
545 getMemRefDescriptorSizes(loc, result_type, dyn_sizes, rewriter, sizes,
546 strides, sizeBytes);
547
548 // Prepare packed args [allocatedPtr, alignedPtr, offset, sizes, strides]
549 // to create a memref descriptor.
550 Value null = rewriter.create<LLVM::NullOp>(
551 loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
552 SmallVector<Value, 12> packed_values{null, null, zero};
553 packed_values.append(sizes);
554 packed_values.append(strides);
555
556 rewriter.replaceOp(
557 op, MemRefDescriptor::pack(rewriter, loc, type_converter, result_type,
558 packed_values));
559 return success();
560 }
561
562 auto result_type = null_memref_op.getType().cast<UnrankedMemRefType>();
563 Type llvm_result_type = type_converter.convertType(result_type);
564
565 auto desc =
566 UnrankedMemRefDescriptor::undef(rewriter, loc, llvm_result_type);
567 desc.setRank(rewriter, loc, zero);
568
569 // Due to the current way of handling unranked memref results escaping, we
570 // have to actually construct a ranked underlying descriptor instead of just
571 // setting its pointer to NULL.
572 SmallVector<Value, 4> sizes;
573 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
574 desc, sizes);
575 Value underlying_desc_ptr = rewriter.create<LLVM::AllocaOp>(
576 loc, getVoidPtrType(), sizes.front(), llvm::None);
577
578 // Populate underlying ranked descriptor.
579 Type elem_ptr_ptr_type = LLVM::LLVMPointerType::get(
580 LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
581
582 Value null = rewriter.create<LLVM::NullOp>(
583 loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
584 UnrankedMemRefDescriptor::setAllocatedPtr(
585 rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, null);
586 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
587 underlying_desc_ptr,
588 elem_ptr_ptr_type, null);
589 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
590 underlying_desc_ptr, elem_ptr_ptr_type,
591 zero);
592
593 desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr);
594 rewriter.replaceOp(op, {desc});
595 return success();
596 }
597 };
598
599 class IsValidMemRefOpConverter
600 : public ConvertOpToLLVMPattern<IsValidMemRefOp> {
601 public:
602 using ConvertOpToLLVMPattern<IsValidMemRefOp>::ConvertOpToLLVMPattern;
603
matchAndRewrite(IsValidMemRefOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const604 LogicalResult matchAndRewrite(
605 IsValidMemRefOp op, ArrayRef<Value> operands,
606 ConversionPatternRewriter &rewriter) const override {
607 Location loc = op.getLoc();
608 MemRefDescriptor desc(IsValidMemRefOp::Adaptor(operands).arg());
609
610 // Compare every size in the descriptor to 0 to check num_elements == 0.
611 int64_t rank = op.arg().getType().cast<MemRefType>().getRank();
612 Value is_empty_shape = rewriter.create<LLVM::ConstantOp>(
613 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
614 Value zero = createIndexConstant(rewriter, loc, 0);
615 for (int i = 0; i < rank; ++i) {
616 Value size = desc.size(rewriter, loc, i);
617 Value is_zero_size = rewriter.create<LLVM::ICmpOp>(
618 loc, rewriter.getI1Type(), LLVM::ICmpPredicate::eq, size, zero);
619 is_empty_shape =
620 rewriter.create<LLVM::OrOp>(loc, is_empty_shape, is_zero_size);
621 }
622
623 Value ptr = rewriter.create<LLVM::BitcastOp>(
624 loc, getVoidPtrType(), desc.allocatedPtr(rewriter, loc));
625 Value null = rewriter.create<LLVM::NullOp>(loc, getVoidPtrType());
626 Value is_not_nullptr = rewriter.create<LLVM::ICmpOp>(
627 loc, rewriter.getI1Type(), LLVM::ICmpPredicate::ne, ptr, null);
628
629 // Valid memref = ptr != NULL || num_elements == 0;
630 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, is_not_nullptr, is_empty_shape);
631 return success();
632 }
633 };
634
635 } // namespace
636
PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter * converter,RewritePatternSet * patterns)637 void PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter *converter,
638 RewritePatternSet *patterns) {
639 // clang-format off
640 patterns->insert<
641 IsValidMemRefOpConverter,
642 JITCompileFromStrOpConverter,
643 JITExecuteOpConverter,
644 NullContextOpConverter,
645 NullMemRefOpConverter,
646 ReportErrorOpConverter,
647 TFAllocOpConverter,
648 TFDeallocOpConverter>(*converter);
649 // clang-format on
650 }
651
652 } // namespace tf_framework
653 } // namespace kernel_gen
654 } // namespace mlir
655