1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/Support/FormatVariadic.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
19 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
26 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
27
28 namespace mlir {
29 namespace kernel_gen {
30 namespace tf_framework {
31 namespace {
32
33 using LLVM::LLVMFuncOp;
34
35 static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
36 static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
37 static constexpr StringRef kCInterfaceReportError =
38 "_mlir_ciface_tf_report_error";
39
40 /// Base class for patterns converting TF Framework ops to function calls.
41 template <typename OpTy>
42 class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
43 public:
44 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
45
46 // Attempts to find function symbol in the module, adds it if not found.
getOrInsertTFFunction(PatternRewriter & rewriter,Operation * op) const47 FlatSymbolRefAttr getOrInsertTFFunction(PatternRewriter &rewriter,
48 Operation *op) const {
49 ModuleOp module = op->getParentOfType<ModuleOp>();
50 StringRef tf_func_name = GetFuncName();
51 auto tf_func = module.lookupSymbol<LLVMFuncOp>(tf_func_name);
52 if (!tf_func) {
53 OpBuilder::InsertionGuard guard(rewriter);
54 rewriter.setInsertionPointToStart(module.getBody());
55 auto func_type = GetFuncType();
56 tf_func = rewriter.create<LLVMFuncOp>(rewriter.getUnknownLoc(),
57 tf_func_name, func_type);
58 }
59 return SymbolRefAttr::get(rewriter.getContext(), tf_func_name);
60 }
61
62 protected:
63 virtual StringRef GetFuncName() const = 0;
64 virtual Type GetFuncType() const = 0;
65 };
66
67 class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
68 public:
69 using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
70
matchAndRewrite(TFAllocOp tf_alloc_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const71 LogicalResult matchAndRewrite(
72 TFAllocOp tf_alloc_op, ArrayRef<Value> operands,
73 ConversionPatternRewriter &rewriter) const override {
74 mlir::Operation *op = tf_alloc_op.getOperation();
75 Location loc = op->getLoc();
76 TFAllocOp::Adaptor transformed(operands);
77
78 MemRefType memref_type = tf_alloc_op.getType();
79
80 // Get memref descriptor sizes.
81 SmallVector<Value, 4> sizes;
82 SmallVector<Value, 4> strides;
83 Value sizeBytes;
84 getMemRefDescriptorSizes(loc, memref_type,
85 llvm::to_vector<4>(transformed.dyn_sizes()),
86 rewriter, sizes, strides, sizeBytes);
87 // Get number of elements.
88 Value num_elements = getNumElements(loc, sizes, rewriter);
89 // Get element size.
90 Value element_size =
91 getSizeInBytes(loc, memref_type.getElementType(), rewriter);
92
93 // Convert `output_index` or set it to -1 if the attribute is missing.
94 Type llvmInt32Type = IntegerType::get(rewriter.getContext(), 32);
95 Value output_index = rewriter.create<LLVM::ConstantOp>(
96 loc, llvmInt32Type,
97 rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
98 ? tf_alloc_op.output_index().getValue()
99 : -1));
100
101 // Convert `candidate_input_indices`.
102 auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray(
103 loc, tf_alloc_op.input_indices(), &rewriter);
104
105 // Insert function call.
106 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
107 Value allocated_byte_ptr =
108 rewriter
109 .create<LLVM::CallOp>(
110 loc, getVoidPtrType(), tf_func_ref,
111 llvm::makeArrayRef({transformed.ctx(), num_elements,
112 element_size, output_index,
113 candidates_count_and_ptr.first,
114 candidates_count_and_ptr.second}))
115 .getResult(0);
116
117 MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
118 loc, rewriter, memref_type, allocated_byte_ptr, sizes);
119
120 // Return the final value of the descriptor.
121 rewriter.replaceOp(op, {memRefDescriptor});
122 return success();
123 }
124
125 protected:
GetFuncName() const126 StringRef GetFuncName() const override { return kCInterfaceAlloc; }
127
GetFuncType() const128 Type GetFuncType() const override {
129 Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
130 Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
131 Type llvm_void_ptr_type = getVoidPtrType();
132 return LLVM::LLVMFunctionType::get(
133 llvm_void_ptr_type,
134 llvm::makeArrayRef(
135 {/*void* op_kernel_ctx*/ llvm_void_ptr_type,
136 /*size_t num_elements*/ getIndexType(),
137 /*size_t element_size*/ getIndexType(),
138 /*int32_t output_index*/ llvm_i32_type,
139 /*int32_t num_candidates*/ llvm_i32_type,
140 /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}));
141 }
142
143 private:
CreateMemRefDescriptor(Location loc,ConversionPatternRewriter & rewriter,MemRefType memref_type,Value allocated_byte_ptr,ArrayRef<Value> sizes) const144 MemRefDescriptor CreateMemRefDescriptor(Location loc,
145 ConversionPatternRewriter &rewriter,
146 MemRefType memref_type,
147 Value allocated_byte_ptr,
148 ArrayRef<Value> sizes) const {
149 auto memref_desc = MemRefDescriptor::undef(
150 rewriter, loc, typeConverter->convertType(memref_type));
151
152 // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr.
153 Value allocated_type_ptr = rewriter.create<LLVM::BitcastOp>(
154 loc, getElementPtrType(memref_type), allocated_byte_ptr);
155 memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr);
156 memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr);
157 memref_desc.setConstantOffset(rewriter, loc, 0);
158
159 if (memref_type.getRank() == 0) {
160 return memref_desc;
161 }
162
163 // Compute strides and populate descriptor `size` and `stride` fields.
164 Value stride_carried = createIndexConstant(rewriter, loc, 1);
165 for (int pos = sizes.size() - 1; pos >= 0; --pos) {
166 Value size = sizes[pos];
167 memref_desc.setSize(rewriter, loc, pos, size);
168 memref_desc.setStride(rewriter, loc, pos, stride_carried);
169 // Update stride
170 if (pos > 0) {
171 stride_carried =
172 rewriter.create<LLVM::MulOp>(loc, stride_carried, size);
173 }
174 }
175 return memref_desc;
176 }
177
ConvertI32ArrayAttrToStackAllocatedArray(Location loc,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const178 std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray(
179 Location loc, llvm::Optional<ArrayAttr> attr,
180 ConversionPatternRewriter *rewriter) const {
181 Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
182 Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
183
184 // If the attribute is missing or empty, set the element count to 0 and
185 // return NULL.
186 if (!attr.hasValue() || attr.getValue().empty()) {
187 Value zero = rewriter->create<LLVM::ConstantOp>(
188 loc, llvm_i32_type, rewriter->getI32IntegerAttr(0));
189 Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type);
190 return std::make_pair(zero, null_ptr);
191 }
192
193 // Allocate array to store the elements.
194 auto &array_attr = attr.getValue();
195 Value array_size = rewriter->create<LLVM::ConstantOp>(
196 loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size()));
197 Value array_ptr = rewriter->create<LLVM::AllocaOp>(
198 loc, llvm_i32_ptr_type, array_size, /*alignment=*/0);
199
200 for (auto &dim : llvm::enumerate(array_attr)) {
201 Value index = rewriter->create<LLVM::ConstantOp>(
202 loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index()));
203 Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type,
204 array_ptr, index);
205 Value elem = rewriter->create<LLVM::ConstantOp>(
206 loc, llvm_i32_type,
207 rewriter->getI32IntegerAttr(
208 dim.value().cast<IntegerAttr>().getInt()));
209 rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr);
210 }
211 return std::make_pair(array_size, array_ptr);
212 }
213 };
214
215 class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
216 public:
217 using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
218
matchAndRewrite(TFDeallocOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const219 LogicalResult matchAndRewrite(
220 TFDeallocOp op, ArrayRef<Value> operands,
221 ConversionPatternRewriter &rewriter) const override {
222 TFDeallocOp::Adaptor transformed(operands);
223 MemRefDescriptor memref(transformed.memref());
224
225 Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
226 op.getLoc(), getVoidPtrType(),
227 memref.allocatedPtr(rewriter, op.getLoc()));
228
229 // Insert function call.
230 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
231 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
232 op, llvm::None, tf_func_ref,
233 llvm::makeArrayRef({transformed.ctx(), allocated_bytes_ptr}));
234 return success();
235 }
236
237 protected:
GetFuncName() const238 StringRef GetFuncName() const override { return kCInterfaceDealloc; }
GetFuncType() const239 Type GetFuncType() const override {
240 return LLVM::LLVMFunctionType::get(getVoidType(),
241 {getVoidPtrType(), getVoidPtrType()});
242 }
243 };
244
245 class ReportErrorOpConverter
246 : public ConvertToLLVMCallOpPattern<ReportErrorOp> {
247 public:
248 using ConvertToLLVMCallOpPattern<ReportErrorOp>::ConvertToLLVMCallOpPattern;
249
matchAndRewrite(ReportErrorOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const250 LogicalResult matchAndRewrite(
251 ReportErrorOp op, ArrayRef<Value> operands,
252 ConversionPatternRewriter &rewriter) const override {
253 ReportErrorOp::Adaptor transformed(operands,
254 op.getOperation()->getAttrDictionary());
255
256 Location loc = op.getLoc();
257 auto module = op->getParentOfType<ModuleOp>();
258 Value message_constant = GenerateErrorMessageConstant(
259 loc, module, transformed.msg().getValue(), rewriter);
260
261 // Insert function call.
262 FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
263 Value error_code = rewriter.create<LLVM::ConstantOp>(
264 loc, typeConverter->convertType(rewriter.getI32Type()),
265 transformed.error_code());
266 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
267
268 op, llvm::None, tf_func_ref,
269 llvm::makeArrayRef({transformed.ctx(), error_code, message_constant}));
270 return success();
271 }
272
273 protected:
GetFuncName() const274 StringRef GetFuncName() const override { return kCInterfaceReportError; }
GetFuncType() const275 Type GetFuncType() const override {
276 MLIRContext *ctx = &getTypeConverter()->getContext();
277 auto i8_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
278 auto i32_type = IntegerType::get(ctx, 32);
279 return LLVM::LLVMFunctionType::get(
280 getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type});
281 }
282
283 private:
284 // Generates an LLVM IR dialect global that contains the name of the given
285 // kernel function as a C string, and returns a pointer to its beginning.
GenerateErrorMessageConstant(Location loc,Operation * module,StringRef message,OpBuilder & builder) const286 Value GenerateErrorMessageConstant(Location loc, Operation *module,
287 StringRef message,
288 OpBuilder &builder) const {
289 std::string loc_str;
290 llvm::raw_string_ostream loc_stream(loc_str);
291 loc_stream << message << " at ";
292 loc.print(loc_stream);
293
294 StringRef generated_error(loc_stream.str().c_str());
295
296 std::string global_name =
297 llvm::formatv("error_message_{0}", llvm::hash_value(generated_error));
298
299 Operation *global_constant =
300 SymbolTable::lookupNearestSymbolFrom(module, global_name);
301
302 if (global_constant) {
303 Value globalPtr = builder.create<LLVM::AddressOfOp>(
304 loc, cast<LLVM::GlobalOp>(global_constant));
305
306 MLIRContext *ctx = &getTypeConverter()->getContext();
307 Value c0 = builder.create<LLVM::ConstantOp>(
308 loc, IntegerType::get(ctx, 64),
309 builder.getIntegerAttr(builder.getIndexType(), 0));
310 return builder.create<LLVM::GEPOp>(
311 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
312 ValueRange{c0, c0});
313 }
314 return LLVM::createGlobalString(loc, builder, global_name, generated_error,
315 LLVM::Linkage::Internal);
316 }
317 };
318
319 class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
320 public:
321 using ConvertOpToLLVMPattern<NullContextOp>::ConvertOpToLLVMPattern;
322
matchAndRewrite(NullContextOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const323 LogicalResult matchAndRewrite(
324 NullContextOp op, ArrayRef<Value> operands,
325 ConversionPatternRewriter &rewriter) const override {
326 rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
327 return success();
328 }
329 };
330
331 class NullMemRefOpConverter : public ConvertOpToLLVMPattern<NullMemRefOp> {
332 public:
333 using ConvertOpToLLVMPattern<NullMemRefOp>::ConvertOpToLLVMPattern;
334
matchAndRewrite(NullMemRefOp null_memref_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const335 LogicalResult matchAndRewrite(
336 NullMemRefOp null_memref_op, ArrayRef<Value> operands,
337 ConversionPatternRewriter &rewriter) const override {
338 mlir::Operation *op = null_memref_op.getOperation();
339
340 Location loc = op->getLoc();
341 auto result_type = null_memref_op.getType().cast<UnrankedMemRefType>();
342 Type llvm_result_type = typeConverter->convertType(result_type);
343
344 auto desc =
345 UnrankedMemRefDescriptor::undef(rewriter, loc, llvm_result_type);
346 Value zero = createIndexConstant(rewriter, loc, 0);
347 desc.setRank(rewriter, loc, zero);
348
349 // Due to the current way of handling unranked memref results escaping, we
350 // have to actually construct a ranked underlying descriptor instead of just
351 // setting its pointer to NULL.
352 SmallVector<Value, 4> sizes;
353 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
354 desc, sizes);
355 Value underlying_desc_ptr = rewriter.create<LLVM::AllocaOp>(
356 loc, getVoidPtrType(), sizes.front(), llvm::None);
357
358 // Populate underlying ranked descriptor.
359 unsigned address_space = result_type.getMemorySpace();
360 Type elem_type = result_type.getElementType();
361 Type llvm_elem_type = typeConverter->convertType(elem_type);
362 Type elem_ptr_ptr_type = LLVM::LLVMPointerType::get(
363 LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
364
365 auto nullPtr = rewriter.create<LLVM::NullOp>(
366 loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
367 UnrankedMemRefDescriptor::setAllocatedPtr(
368 rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, nullPtr);
369 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
370 underlying_desc_ptr,
371 elem_ptr_ptr_type, nullPtr);
372 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
373 underlying_desc_ptr, elem_ptr_ptr_type,
374 zero);
375
376 desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr);
377 rewriter.replaceOp(op, {desc});
378 return success();
379 }
380 };
381
382 } // namespace
383
PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter * converter,OwningRewritePatternList * patterns)384 void PopulateTFFrameworkToLLVMConversionPatterns(
385 LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
386 // clang-format off
387 patterns->insert<
388 NullContextOpConverter,
389 NullMemRefOpConverter,
390 ReportErrorOpConverter,
391 TFAllocOpConverter,
392 TFDeallocOpConverter
393 >(*converter);
394 // clang-format on
395 }
396
397 } // namespace tf_framework
398 } // namespace kernel_gen
399 } // namespace mlir
400