1 //===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR standard and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
16 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/Support/LogicalResult.h"
27 #include "mlir/Support/MathExtras.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "mlir/Transforms/Passes.h"
30 #include "mlir/Transforms/Utils.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/IR/DerivedTypes.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/FormatVariadic.h"
37 #include <functional>
38
39 using namespace mlir;
40
41 #define PASS_NAME "convert-std-to-llvm"
42
43 // Extract an LLVM IR type from the LLVM IR dialect type.
unwrap(Type type)44 static LLVM::LLVMType unwrap(Type type) {
45 if (!type)
46 return nullptr;
47 auto *mlirContext = type.getContext();
48 auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
49 if (!wrappedLLVMType)
50 emitError(UnknownLoc::get(mlirContext),
51 "conversion resulted in a non-LLVM type");
52 return wrappedLLVMType;
53 }
54
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
structFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)59 LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
60 Type type,
61 SmallVectorImpl<Type> &result) {
62 if (auto memref = type.dyn_cast<MemRefType>()) {
63 // In signatures, Memref descriptors are expanded into lists of
64 // non-aggregate values.
65 auto converted =
66 converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
67 if (converted.empty())
68 return failure();
69 result.append(converted.begin(), converted.end());
70 return success();
71 }
72 if (type.isa<UnrankedMemRefType>()) {
73 auto converted = converter.getUnrankedMemRefDescriptorFields();
74 if (converted.empty())
75 return failure();
76 result.append(converted.begin(), converted.end());
77 return success();
78 }
79 auto converted = converter.convertType(type);
80 if (!converted)
81 return failure();
82 result.push_back(converted);
83 return success();
84 }
85
86 /// Callback to convert function argument types. It converts MemRef function
87 /// arguments to bare pointers to the MemRef element type.
barePtrFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)88 LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
89 Type type,
90 SmallVectorImpl<Type> &result) {
91 auto llvmTy = converter.convertCallingConventionType(type);
92 if (!llvmTy)
93 return failure();
94
95 result.push_back(llvmTy);
96 return success();
97 }
98
99 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx)100 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
101 : LLVMTypeConverter(ctx, LowerToLLVMOptions::getDefaultOptions()) {}
102
103 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx,const LowerToLLVMOptions & options)104 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
105 const LowerToLLVMOptions &options)
106 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
107 options(options) {
108 assert(llvmDialect && "LLVM IR dialect is not registered");
109 if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
110 this->options.indexBitwidth = options.dataLayout.getPointerSizeInBits();
111
112 // Register conversions for the builtin types.
113 addConversion([&](ComplexType type) { return convertComplexType(type); });
114 addConversion([&](FloatType type) { return convertFloatType(type); });
115 addConversion([&](FunctionType type) { return convertFunctionType(type); });
116 addConversion([&](IndexType type) { return convertIndexType(type); });
117 addConversion([&](IntegerType type) { return convertIntegerType(type); });
118 addConversion([&](MemRefType type) { return convertMemRefType(type); });
119 addConversion(
120 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
121 addConversion([&](VectorType type) { return convertVectorType(type); });
122
123 // LLVMType is legal, so add a pass-through conversion.
124 addConversion([](LLVM::LLVMType type) { return type; });
125
126 // Materialization for memrefs creates descriptor structs from individual
127 // values constituting them, when descriptors are used, i.e. more than one
128 // value represents a memref.
129 addArgumentMaterialization(
130 [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
131 Location loc) -> Optional<Value> {
132 if (inputs.size() == 1)
133 return llvm::None;
134 return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
135 inputs);
136 });
137 addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
138 ValueRange inputs,
139 Location loc) -> Optional<Value> {
140 if (inputs.size() == 1)
141 return llvm::None;
142 return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
143 });
144 // Add generic source and target materializations to handle cases where
145 // non-LLVM types persist after an LLVM conversion.
146 addSourceMaterialization([&](OpBuilder &builder, Type resultType,
147 ValueRange inputs,
148 Location loc) -> Optional<Value> {
149 if (inputs.size() != 1)
150 return llvm::None;
151 // FIXME: These should check LLVM::DialectCastOp can actually be constructed
152 // from the input and result.
153 return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
154 .getResult();
155 });
156 addTargetMaterialization([&](OpBuilder &builder, Type resultType,
157 ValueRange inputs,
158 Location loc) -> Optional<Value> {
159 if (inputs.size() != 1)
160 return llvm::None;
161 // FIXME: These should check LLVM::DialectCastOp can actually be constructed
162 // from the input and result.
163 return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
164 .getResult();
165 });
166 }
167
168 /// Returns the MLIR context.
getContext()169 MLIRContext &LLVMTypeConverter::getContext() {
170 return *getDialect()->getContext();
171 }
172
getIndexType()173 LLVM::LLVMType LLVMTypeConverter::getIndexType() {
174 return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
175 }
176
getPointerBitwidth(unsigned addressSpace)177 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
178 return options.dataLayout.getPointerSizeInBits(addressSpace);
179 }
180
convertIndexType(IndexType type)181 Type LLVMTypeConverter::convertIndexType(IndexType type) {
182 return getIndexType();
183 }
184
convertIntegerType(IntegerType type)185 Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
186 return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
187 }
188
convertFloatType(FloatType type)189 Type LLVMTypeConverter::convertFloatType(FloatType type) {
190 if (type.isa<Float32Type>())
191 return LLVM::LLVMType::getFloatTy(&getContext());
192 if (type.isa<Float64Type>())
193 return LLVM::LLVMType::getDoubleTy(&getContext());
194 if (type.isa<Float16Type>())
195 return LLVM::LLVMType::getHalfTy(&getContext());
196 if (type.isa<BFloat16Type>())
197 return LLVM::LLVMType::getBFloatTy(&getContext());
198 llvm_unreachable("non-float type in convertFloatType");
199 }
200
201 // Convert a `ComplexType` to an LLVM type. The result is a complex number
202 // struct with entries for the
203 // 1. real part and for the
204 // 2. imaginary part.
205 static constexpr unsigned kRealPosInComplexNumberStruct = 0;
206 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
convertComplexType(ComplexType type)207 Type LLVMTypeConverter::convertComplexType(ComplexType type) {
208 auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
209 return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType});
210 }
211
212 // Except for signatures, MLIR function types are converted into LLVM
213 // pointer-to-function types.
convertFunctionType(FunctionType type)214 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
215 SignatureConversion conversion(type.getNumInputs());
216 LLVM::LLVMType converted =
217 convertFunctionSignature(type, /*isVariadic=*/false, conversion);
218 return converted.getPointerTo();
219 }
220
221
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are packed into an LLVM StructType in their order of appearance.
convertFunctionSignature(FunctionType funcTy,bool isVariadic,LLVMTypeConverter::SignatureConversion & result)226 LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
227 FunctionType funcTy, bool isVariadic,
228 LLVMTypeConverter::SignatureConversion &result) {
229 // Select the argument converter depending on the calling convention.
230 auto funcArgConverter = options.useBarePtrCallConv
231 ? barePtrFuncArgTypeConverter
232 : structFuncArgTypeConverter;
233 // Convert argument types one by one and check for errors.
234 for (auto &en : llvm::enumerate(funcTy.getInputs())) {
235 Type type = en.value();
236 SmallVector<Type, 8> converted;
237 if (failed(funcArgConverter(*this, type, converted)))
238 return {};
239 result.addInputs(en.index(), converted);
240 }
241
242 SmallVector<LLVM::LLVMType, 8> argTypes;
243 argTypes.reserve(llvm::size(result.getConvertedTypes()));
244 for (Type type : result.getConvertedTypes())
245 argTypes.push_back(unwrap(type));
246
247 // If function does not return anything, create the void result type,
248 // if it returns on element, convert it, otherwise pack the result types into
249 // a struct.
250 LLVM::LLVMType resultType =
251 funcTy.getNumResults() == 0
252 ? LLVM::LLVMType::getVoidTy(&getContext())
253 : unwrap(packFunctionResults(funcTy.getResults()));
254 if (!resultType)
255 return {};
256 return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
257 }
258
259 /// Converts the function type to a C-compatible format, in particular using
260 /// pointers to memref descriptors for arguments.
261 LLVM::LLVMType
convertFunctionTypeCWrapper(FunctionType type)262 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
263 SmallVector<LLVM::LLVMType, 4> inputs;
264
265 for (Type t : type.getInputs()) {
266 auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
267 if (!converted)
268 return {};
269 if (t.isa<MemRefType, UnrankedMemRefType>())
270 converted = converted.getPointerTo();
271 inputs.push_back(converted);
272 }
273
274 LLVM::LLVMType resultType =
275 type.getNumResults() == 0
276 ? LLVM::LLVMType::getVoidTy(&getContext())
277 : unwrap(packFunctionResults(type.getResults()));
278 if (!resultType)
279 return {};
280
281 return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
282 }
283
284 static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
285 static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
286 static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
287 static constexpr unsigned kSizePosInMemRefDescriptor = 3;
288 static constexpr unsigned kStridePosInMemRefDescriptor = 4;
289
290 /// Convert a memref type into a list of LLVM IR types that will form the
291 /// memref descriptor. The result contains the following types:
292 /// 1. The pointer to the allocated data buffer, followed by
293 /// 2. The pointer to the aligned data buffer, followed by
294 /// 3. A lowered `index`-type integer containing the distance between the
295 /// beginning of the buffer and the first element to be accessed through the
296 /// view, followed by
297 /// 4. An array containing as many `index`-type integers as the rank of the
298 /// MemRef: the array represents the size, in number of elements, of the memref
299 /// along the given dimension. For constant MemRef dimensions, the
300 /// corresponding size entry is a constant whose runtime value must match the
301 /// static value, followed by
302 /// 5. A second array containing as many `index`-type integers as the rank of
303 /// the MemRef: the second array represents the "stride" (in tensor abstraction
304 /// sense), i.e. the number of consecutive elements of the underlying buffer.
305 /// TODO: add assertions for the static cases.
306 ///
307 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
308 /// are expanded into individual index-type elements.
309 ///
310 /// template <typename Elem, typename Index, size_t Rank>
311 /// struct {
312 /// Elem *allocatedPtr;
313 /// Elem *alignedPtr;
314 /// Index offset;
315 /// Index sizes[Rank]; // omitted when rank == 0
316 /// Index strides[Rank]; // omitted when rank == 0
317 /// };
318 SmallVector<LLVM::LLVMType, 5>
getMemRefDescriptorFields(MemRefType type,bool unpackAggregates)319 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
320 bool unpackAggregates) {
321 assert(isStrided(type) &&
322 "Non-strided layout maps must have been normalized away");
323
324 LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
325 if (!elementType)
326 return {};
327 auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
328 auto indexTy = getIndexType();
329
330 SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
331 auto rank = type.getRank();
332 if (rank == 0)
333 return results;
334
335 if (unpackAggregates)
336 results.insert(results.end(), 2 * rank, indexTy);
337 else
338 results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank));
339 return results;
340 }
341
342 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
343 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
convertMemRefType(MemRefType type)344 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
345 // When converting a MemRefType to a struct with descriptor fields, do not
346 // unpack the `sizes` and `strides` arrays.
347 SmallVector<LLVM::LLVMType, 5> types =
348 getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
349 return LLVM::LLVMType::getStructTy(&getContext(), types);
350 }
351
352 static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
353 static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
354
355 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
356 /// that will form the unranked memref descriptor. In particular, the fields
357 /// for an unranked memref descriptor are:
358 /// 1. index-typed rank, the dynamic rank of this MemRef
359 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
360 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
361 /// be unranked.
362 SmallVector<LLVM::LLVMType, 2>
getUnrankedMemRefDescriptorFields()363 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
364 return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
365 }
366
convertUnrankedMemRefType(UnrankedMemRefType type)367 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
368 return LLVM::LLVMType::getStructTy(&getContext(),
369 getUnrankedMemRefDescriptorFields());
370 }
371
372 /// Convert a memref type to a bare pointer to the memref element type.
convertMemRefToBarePtr(BaseMemRefType type)373 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
374 if (type.isa<UnrankedMemRefType>())
375 // Unranked memref is not supported in the bare pointer calling convention.
376 return {};
377
378 // Check that the memref has static shape, strides and offset. Otherwise, it
379 // cannot be lowered to a bare pointer.
380 auto memrefTy = type.cast<MemRefType>();
381 if (!memrefTy.hasStaticShape())
382 return {};
383
384 int64_t offset = 0;
385 SmallVector<int64_t, 4> strides;
386 if (failed(getStridesAndOffset(memrefTy, strides, offset)))
387 return {};
388
389 for (int64_t stride : strides)
390 if (ShapedType::isDynamicStrideOrOffset(stride))
391 return {};
392
393 if (ShapedType::isDynamicStrideOrOffset(offset))
394 return {};
395
396 LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
397 if (!elementType)
398 return {};
399 return elementType.getPointerTo(type.getMemorySpace());
400 }
401
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm<"<4 x float>">` and
// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
convertVectorType(VectorType type)406 Type LLVMTypeConverter::convertVectorType(VectorType type) {
407 auto elementType = unwrap(convertType(type.getElementType()));
408 if (!elementType)
409 return {};
410 auto vectorType =
411 LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
412 auto shape = type.getShape();
413 for (int i = shape.size() - 2; i >= 0; --i)
414 vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
415 return vectorType;
416 }
417
418 /// Convert a type in the context of the default or bare pointer calling
419 /// convention. Calling convention sensitive types, such as MemRefType and
420 /// UnrankedMemRefType, are converted following the specific rules for the
421 /// calling convention. Calling convention independent types are converted
422 /// following the default LLVM type conversions.
convertCallingConventionType(Type type)423 Type LLVMTypeConverter::convertCallingConventionType(Type type) {
424 if (options.useBarePtrCallConv)
425 if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
426 return convertMemRefToBarePtr(memrefTy);
427
428 return convertType(type);
429 }
430
431 /// Promote the bare pointers in 'values' that resulted from memrefs to
432 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
433 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
promoteBarePtrsToDescriptors(ConversionPatternRewriter & rewriter,Location loc,ArrayRef<Type> stdTypes,SmallVectorImpl<Value> & values)434 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
435 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
436 SmallVectorImpl<Value> &values) {
437 assert(stdTypes.size() == values.size() &&
438 "The number of types and values doesn't match");
439 for (unsigned i = 0, end = values.size(); i < end; ++i)
440 if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
441 values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
442 memrefTy, values[i]);
443 }
444
ConvertToLLVMPattern(StringRef rootOpName,MLIRContext * context,LLVMTypeConverter & typeConverter,PatternBenefit benefit)445 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
446 MLIRContext *context,
447 LLVMTypeConverter &typeConverter,
448 PatternBenefit benefit)
449 : ConversionPattern(rootOpName, benefit, typeConverter, context) {}
450
451 //===----------------------------------------------------------------------===//
452 // StructBuilder implementation
453 //===----------------------------------------------------------------------===//
454
StructBuilder(Value v)455 StructBuilder::StructBuilder(Value v) : value(v) {
456 assert(value != nullptr && "value cannot be null");
457 structType = value.getType().dyn_cast<LLVM::LLVMType>();
458 assert(structType && "expected llvm type");
459 }
460
extractPtr(OpBuilder & builder,Location loc,unsigned pos)461 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
462 unsigned pos) {
463 Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
464 return builder.create<LLVM::ExtractValueOp>(loc, type, value,
465 builder.getI64ArrayAttr(pos));
466 }
467
setPtr(OpBuilder & builder,Location loc,unsigned pos,Value ptr)468 void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
469 Value ptr) {
470 value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
471 builder.getI64ArrayAttr(pos));
472 }
473
474 //===----------------------------------------------------------------------===//
475 // ComplexStructBuilder implementation
476 //===----------------------------------------------------------------------===//
477
undef(OpBuilder & builder,Location loc,Type type)478 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
479 Location loc, Type type) {
480 Value val = builder.create<LLVM::UndefOp>(loc, type.cast<LLVM::LLVMType>());
481 return ComplexStructBuilder(val);
482 }
483
setReal(OpBuilder & builder,Location loc,Value real)484 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
485 Value real) {
486 setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
487 }
488
real(OpBuilder & builder,Location loc)489 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
490 return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
491 }
492
setImaginary(OpBuilder & builder,Location loc,Value imaginary)493 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
494 Value imaginary) {
495 setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
496 }
497
imaginary(OpBuilder & builder,Location loc)498 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
499 return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
500 }
501
502 //===----------------------------------------------------------------------===//
503 // MemRefDescriptor implementation
504 //===----------------------------------------------------------------------===//
505
506 /// Construct a helper for the given descriptor value.
MemRefDescriptor(Value descriptor)507 MemRefDescriptor::MemRefDescriptor(Value descriptor)
508 : StructBuilder(descriptor) {
509 assert(value != nullptr && "value cannot be null");
510 indexType = value.getType().cast<LLVM::LLVMType>().getStructElementType(
511 kOffsetPosInMemRefDescriptor);
512 }
513
514 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)515 MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
516 Type descriptorType) {
517
518 Value descriptor =
519 builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
520 return MemRefDescriptor(descriptor);
521 }
522
523 /// Builds IR creating a MemRef descriptor that represents `type` and
524 /// populates it with static shape and stride information extracted from the
525 /// type.
526 MemRefDescriptor
fromStaticShape(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,MemRefType type,Value memory)527 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
528 LLVMTypeConverter &typeConverter,
529 MemRefType type, Value memory) {
530 assert(type.hasStaticShape() && "unexpected dynamic shape");
531
532 // Extract all strides and offsets and verify they are static.
533 int64_t offset;
534 SmallVector<int64_t, 4> strides;
535 auto result = getStridesAndOffset(type, strides, offset);
536 (void)result;
537 assert(succeeded(result) && "unexpected failure in stride computation");
538 assert(offset != MemRefType::getDynamicStrideOrOffset() &&
539 "expected static offset");
540 assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
541 "expected static strides");
542
543 auto convertedType = typeConverter.convertType(type);
544 assert(convertedType && "unexpected failure in memref type conversion");
545
546 auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
547 descr.setAllocatedPtr(builder, loc, memory);
548 descr.setAlignedPtr(builder, loc, memory);
549 descr.setConstantOffset(builder, loc, offset);
550
551 // Fill in sizes and strides
552 for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
553 descr.setConstantSize(builder, loc, i, type.getDimSize(i));
554 descr.setConstantStride(builder, loc, i, strides[i]);
555 }
556 return descr;
557 }
558
559 /// Builds IR extracting the allocated pointer from the descriptor.
allocatedPtr(OpBuilder & builder,Location loc)560 Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
561 return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
562 }
563
564 /// Builds IR inserting the allocated pointer into the descriptor.
setAllocatedPtr(OpBuilder & builder,Location loc,Value ptr)565 void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
566 Value ptr) {
567 setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
568 }
569
570 /// Builds IR extracting the aligned pointer from the descriptor.
alignedPtr(OpBuilder & builder,Location loc)571 Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
572 return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
573 }
574
575 /// Builds IR inserting the aligned pointer into the descriptor.
setAlignedPtr(OpBuilder & builder,Location loc,Value ptr)576 void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
577 Value ptr) {
578 setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
579 }
580
581 // Creates a constant Op producing a value of `resultType` from an index-typed
582 // integer attribute.
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)583 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
584 Type resultType, int64_t value) {
585 return builder.create<LLVM::ConstantOp>(
586 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
587 }
588
589 /// Builds IR extracting the offset from the descriptor.
offset(OpBuilder & builder,Location loc)590 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
591 return builder.create<LLVM::ExtractValueOp>(
592 loc, indexType, value,
593 builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
594 }
595
596 /// Builds IR inserting the offset into the descriptor.
setOffset(OpBuilder & builder,Location loc,Value offset)597 void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
598 Value offset) {
599 value = builder.create<LLVM::InsertValueOp>(
600 loc, structType, value, offset,
601 builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
602 }
603
604 /// Builds IR inserting the offset into the descriptor.
setConstantOffset(OpBuilder & builder,Location loc,uint64_t offset)605 void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
606 uint64_t offset) {
607 setOffset(builder, loc,
608 createIndexAttrConstant(builder, loc, indexType, offset));
609 }
610
611 /// Builds IR extracting the pos-th size from the descriptor.
size(OpBuilder & builder,Location loc,unsigned pos)612 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
613 return builder.create<LLVM::ExtractValueOp>(
614 loc, indexType, value,
615 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
616 }
617
size(OpBuilder & builder,Location loc,Value pos,int64_t rank)618 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
619 int64_t rank) {
620 auto indexTy = indexType.cast<LLVM::LLVMType>();
621 auto indexPtrTy = indexTy.getPointerTo();
622 auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
623 auto arrayPtrTy = arrayTy.getPointerTo();
624
625 // Copy size values to stack-allocated memory.
626 auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
627 auto one = createIndexAttrConstant(builder, loc, indexType, 1);
628 auto sizes = builder.create<LLVM::ExtractValueOp>(
629 loc, arrayTy, value,
630 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
631 auto sizesPtr =
632 builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
633 builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
634
635 // Load an return size value of interest.
636 auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
637 ValueRange({zero, pos}));
638 return builder.create<LLVM::LoadOp>(loc, resultPtr);
639 }
640
641 /// Builds IR inserting the pos-th size into the descriptor
setSize(OpBuilder & builder,Location loc,unsigned pos,Value size)642 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
643 Value size) {
644 value = builder.create<LLVM::InsertValueOp>(
645 loc, structType, value, size,
646 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
647 }
648
setConstantSize(OpBuilder & builder,Location loc,unsigned pos,uint64_t size)649 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
650 unsigned pos, uint64_t size) {
651 setSize(builder, loc, pos,
652 createIndexAttrConstant(builder, loc, indexType, size));
653 }
654
655 /// Builds IR extracting the pos-th stride from the descriptor.
stride(OpBuilder & builder,Location loc,unsigned pos)656 Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
657 return builder.create<LLVM::ExtractValueOp>(
658 loc, indexType, value,
659 builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
660 }
661
662 /// Builds IR inserting the pos-th stride into the descriptor
setStride(OpBuilder & builder,Location loc,unsigned pos,Value stride)663 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
664 Value stride) {
665 value = builder.create<LLVM::InsertValueOp>(
666 loc, structType, value, stride,
667 builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
668 }
669
setConstantStride(OpBuilder & builder,Location loc,unsigned pos,uint64_t stride)670 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
671 unsigned pos, uint64_t stride) {
672 setStride(builder, loc, pos,
673 createIndexAttrConstant(builder, loc, indexType, stride));
674 }
675
getElementPtrType()676 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
677 return value.getType()
678 .cast<LLVM::LLVMType>()
679 .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
680 .cast<LLVM::LLVMPointerType>();
681 }
682
683 /// Creates a MemRef descriptor structure from a list of individual values
684 /// composing that descriptor, in the following order:
685 /// - allocated pointer;
686 /// - aligned pointer;
687 /// - offset;
688 /// - <rank> sizes;
689 /// - <rank> shapes;
690 /// where <rank> is the MemRef rank as provided in `type`.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,MemRefType type,ValueRange values)691 Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
692 LLVMTypeConverter &converter, MemRefType type,
693 ValueRange values) {
694 Type llvmType = converter.convertType(type);
695 auto d = MemRefDescriptor::undef(builder, loc, llvmType);
696
697 d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
698 d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
699 d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
700
701 int64_t rank = type.getRank();
702 for (unsigned i = 0; i < rank; ++i) {
703 d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
704 d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
705 }
706
707 return d;
708 }
709
710 /// Builds IR extracting individual elements of a MemRef descriptor structure
711 /// and returning them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,MemRefType type,SmallVectorImpl<Value> & results)712 void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
713 MemRefType type,
714 SmallVectorImpl<Value> &results) {
715 int64_t rank = type.getRank();
716 results.reserve(results.size() + getNumUnpackedValues(type));
717
718 MemRefDescriptor d(packed);
719 results.push_back(d.allocatedPtr(builder, loc));
720 results.push_back(d.alignedPtr(builder, loc));
721 results.push_back(d.offset(builder, loc));
722 for (int64_t i = 0; i < rank; ++i)
723 results.push_back(d.size(builder, loc, i));
724 for (int64_t i = 0; i < rank; ++i)
725 results.push_back(d.stride(builder, loc, i));
726 }
727
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
  // Two pointers, offset, <rank> sizes, <rank> strides (matching the order
  // produced by `unpack` above).
  return 3 + 2 * type.getRank();
}
734
735 //===----------------------------------------------------------------------===//
736 // MemRefDescriptorView implementation.
737 //===----------------------------------------------------------------------===//
738
/// Constructs a view over an already-unpacked descriptor value list. The
/// first kSizePosInMemRefDescriptor entries are the two pointers and the
/// offset; what remains is <rank> sizes followed by <rank> strides, so the
/// rank is half of the remainder.
MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
    : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
741
/// Returns the allocated-pointer element of the unpacked descriptor.
Value MemRefDescriptorView::allocatedPtr() {
  return elements[kAllocatedPtrPosInMemRefDescriptor];
}
745
/// Returns the aligned-pointer element of the unpacked descriptor.
Value MemRefDescriptorView::alignedPtr() {
  return elements[kAlignedPtrPosInMemRefDescriptor];
}
749
/// Returns the offset element of the unpacked descriptor.
Value MemRefDescriptorView::offset() {
  return elements[kOffsetPosInMemRefDescriptor];
}
753
/// Returns the pos-th size element of the unpacked descriptor.
Value MemRefDescriptorView::size(unsigned pos) {
  return elements[kSizePosInMemRefDescriptor + pos];
}
757
/// Returns the pos-th stride element; strides are stored after the <rank>
/// sizes in the unpacked list.
Value MemRefDescriptorView::stride(unsigned pos) {
  return elements[kSizePosInMemRefDescriptor + rank + pos];
}
761
762 //===----------------------------------------------------------------------===//
763 // UnrankedMemRefDescriptor implementation
764 //===----------------------------------------------------------------------===//
765
/// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
    : StructBuilder(descriptor) {}
769
770 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)771 UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
772 Location loc,
773 Type descriptorType) {
774 Value descriptor =
775 builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
776 return UnrankedMemRefDescriptor(descriptor);
777 }
/// Builds IR extracting the rank field from the descriptor.
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
  return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
}
/// Builds IR setting the rank field of the descriptor to `v`.
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
                                       Value v) {
  setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
}
/// Builds IR extracting the pointer to the underlying ranked descriptor.
Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
                                              Location loc) {
  return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
}
/// Builds IR setting the pointer to the underlying ranked descriptor to `v`.
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
                                                Location loc, Value v) {
  setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}
793
794 /// Builds IR populating an unranked MemRef descriptor structure from a list
795 /// of individual constituent values in the following order:
796 /// - rank of the memref;
797 /// - pointer to the memref descriptor.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,UnrankedMemRefType type,ValueRange values)798 Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
799 LLVMTypeConverter &converter,
800 UnrankedMemRefType type,
801 ValueRange values) {
802 Type llvmType = converter.convertType(type);
803 auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
804
805 d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
806 d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
807 return d;
808 }
809
810 /// Builds IR extracting individual elements that compose an unranked memref
811 /// descriptor and returns them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,SmallVectorImpl<Value> & results)812 void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
813 Value packed,
814 SmallVectorImpl<Value> &results) {
815 UnrankedMemRefDescriptor d(packed);
816 results.reserve(results.size() + 2);
817 results.push_back(d.rank(builder, loc));
818 results.push_back(d.memRefDescPtr(builder, loc));
819 }
820
/// Builds IR computing, for each unranked memref descriptor in `values`, the
/// number of bytes needed to store a copy of the ranked descriptor it points
/// to, appending one size value per descriptor to `sizes`.
void UnrankedMemRefDescriptor::computeSizes(
    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
    ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
  if (values.empty())
    return;

  // Cache the index type.
  LLVM::LLVMType indexType = typeConverter.getIndexType();

  // Initialize shared constants. Bitwidths are rounded up to whole bytes.
  Value one = createIndexAttrConstant(builder, loc, indexType, 1);
  Value two = createIndexAttrConstant(builder, loc, indexType, 2);
  Value pointerSize = createIndexAttrConstant(
      builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
  Value indexSize =
      createIndexAttrConstant(builder, loc, indexType,
                              ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));

  sizes.reserve(sizes.size() + values.size());
  for (UnrankedMemRefDescriptor desc : values) {
    // Emit IR computing the memory necessary to store the descriptor. This
    // assumes the descriptor to be
    //   { type*, type*, index, index[rank], index[rank] }
    // and densely packed, so the total size is
    //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
    // TODO: consider including the actual size (including eventual padding due
    // to data layout) into the unranked descriptor.
    Value doublePointerSize =
        builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);

    // (1 + 2 * rank) * sizeof(index); the rank is only known at runtime here.
    Value rank = desc.rank(builder, loc);
    Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
    Value doubleRankIncremented =
        builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
    Value rankIndexSize = builder.create<LLVM::MulOp>(
        loc, indexType, doubleRankIncremented, indexSize);

    // Total allocation size.
    Value allocationSize = builder.create<LLVM::AddOp>(
        loc, indexType, doublePointerSize, rankIndexSize);
    sizes.push_back(allocationSize);
  }
}
865
allocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType)866 Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
867 Value memRefDescPtr,
868 LLVM::LLVMType elemPtrPtrType) {
869
870 Value elementPtrPtr =
871 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
872 return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
873 }
874
setAllocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType,Value allocatedPtr)875 void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
876 Value memRefDescPtr,
877 LLVM::LLVMType elemPtrPtrType,
878 Value allocatedPtr) {
879 Value elementPtrPtr =
880 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
881 builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
882 }
883
alignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType)884 Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
885 LLVMTypeConverter &typeConverter,
886 Value memRefDescPtr,
887 LLVM::LLVMType elemPtrPtrType) {
888 Value elementPtrPtr =
889 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
890
891 Value one =
892 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
893 Value alignedGep = builder.create<LLVM::GEPOp>(
894 loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
895 return builder.create<LLVM::LoadOp>(loc, alignedGep);
896 }
897
setAlignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType,Value alignedPtr)898 void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
899 LLVMTypeConverter &typeConverter,
900 Value memRefDescPtr,
901 LLVM::LLVMType elemPtrPtrType,
902 Value alignedPtr) {
903 Value elementPtrPtr =
904 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
905
906 Value one =
907 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
908 Value alignedGep = builder.create<LLVM::GEPOp>(
909 loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
910 builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
911 }
912
offset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType)913 Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
914 LLVMTypeConverter &typeConverter,
915 Value memRefDescPtr,
916 LLVM::LLVMType elemPtrPtrType) {
917 Value elementPtrPtr =
918 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
919
920 Value two =
921 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
922 Value offsetGep = builder.create<LLVM::GEPOp>(
923 loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
924 offsetGep = builder.create<LLVM::BitcastOp>(
925 loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
926 return builder.create<LLVM::LoadOp>(loc, offsetGep);
927 }
928
setOffset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType,Value offset)929 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
930 LLVMTypeConverter &typeConverter,
931 Value memRefDescPtr,
932 LLVM::LLVMType elemPtrPtrType,
933 Value offset) {
934 Value elementPtrPtr =
935 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
936
937 Value two =
938 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
939 Value offsetGep = builder.create<LLVM::GEPOp>(
940 loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
941 offsetGep = builder.create<LLVM::BitcastOp>(
942 loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
943 builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
944 }
945
/// Builds IR computing the pointer to the sizes array of the underlying
/// ranked descriptor, i.e. field 3 of
///   { elemPtr, elemPtr, index, index[rank], index[rank] }.
Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
                                            LLVMTypeConverter &typeConverter,
                                            Value memRefDescPtr,
                                            LLVM::LLVMType elemPtrPtrType) {
  LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
  LLVM::LLVMType indexTy = typeConverter.getIndexType();
  // Model only a fixed-size prefix of the descriptor; a GEP to its field 3
  // yields the start of the trailing index array regardless of the rank.
  LLVM::LLVMType structPtrTy =
      LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)
          .getPointerTo();
  Value structPtr =
      builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);

  // The struct-field index `three` is deliberately built as an i32 constant,
  // unlike the index-typed `zero` used for the leading array index.
  LLVM::LLVMType int32_type =
      unwrap(typeConverter.convertType(builder.getI32Type()));
  Value zero =
      createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
  Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
                                                 builder.getI32IntegerAttr(3));
  return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), structPtr,
                                     ValueRange({zero, three}));
}
967
/// Builds IR loading the size at position `index` from the sizes array that
/// starts at `sizeBasePtr`.
// NOTE(review): `typeConverter` is taken by value here (and in setSize /
// stride / setStride) while sibling functions take it by reference; a const
// reference would avoid the copy — confirm against the declaration in the
// header before changing.
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
                                     LLVMTypeConverter typeConverter,
                                     Value sizeBasePtr, Value index) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                                   ValueRange({index}));
  return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
}
976
/// Builds IR storing `size` at position `index` of the sizes array that
/// starts at `sizeBasePtr`.
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter typeConverter,
                                       Value sizeBasePtr, Value index,
                                       Value size) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                                   ValueRange({index}));
  builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
}
986
/// Builds IR computing the pointer to the strides array: it starts `rank`
/// index elements past `sizeBasePtr` (strides immediately follow the sizes).
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
                                              LLVMTypeConverter &typeConverter,
                                              Value sizeBasePtr, Value rank) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                     ValueRange({rank}));
}
994
/// Builds IR loading the stride at position `index` from the strides array
/// that starts at `strideBasePtr`.
// NOTE(review): the trailing `stride` parameter is never read in this getter;
// it appears to mirror setStride's signature — consider dropping it from the
// declaration in a follow-up.
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter typeConverter,
                                       Value strideBasePtr, Value index,
                                       Value stride) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value strideStoreGep = builder.create<LLVM::GEPOp>(
      loc, indexPtrTy, strideBasePtr, ValueRange({index}));
  return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
}
1004
setStride(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value strideBasePtr,Value index,Value stride)1005 void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
1006 LLVMTypeConverter typeConverter,
1007 Value strideBasePtr, Value index,
1008 Value stride) {
1009 LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
1010 Value strideStoreGep = builder.create<LLVM::GEPOp>(
1011 loc, indexPtrTy, strideBasePtr, ValueRange({index}));
1012 builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
1013 }
1014
/// Returns the type converter, downcast to LLVMTypeConverter. The static_cast
/// assumes these patterns are always constructed with an LLVMTypeConverter
/// (as the constructors in this file do).
LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
  return static_cast<LLVMTypeConverter *>(
      ConversionPattern::getTypeConverter());
}
1019
/// Returns the LLVM dialect held by the type converter.
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
  return *getTypeConverter()->getDialect();
}
1023
/// Returns the LLVM type to which the builtin `index` type is lowered.
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
  return getTypeConverter()->getIndexType();
}
1027
1028 LLVM::LLVMType
getIntPtrType(unsigned addressSpace) const1029 ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
1030 return LLVM::LLVMType::getIntNTy(
1031 &getTypeConverter()->getContext(),
1032 getTypeConverter()->getPointerBitwidth(addressSpace));
1033 }
1034
/// Returns the LLVM void type.
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
  return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext());
}
1038
/// Returns the LLVM i8* type, used as the generic "void pointer" type.
LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
  return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext());
}
1042
/// Creates an LLVM constant of the lowered index type holding `value`.
Value ConvertToLLVMPattern::createIndexConstant(
    ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
  return createIndexAttrConstant(builder, loc, getIndexType(), value);
}
1047
getStridedElementPtr(Location loc,MemRefType type,Value memRefDesc,ValueRange indices,ConversionPatternRewriter & rewriter) const1048 Value ConvertToLLVMPattern::getStridedElementPtr(
1049 Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
1050 ConversionPatternRewriter &rewriter) const {
1051
1052 int64_t offset;
1053 SmallVector<int64_t, 4> strides;
1054 auto successStrides = getStridesAndOffset(type, strides, offset);
1055 assert(succeeded(successStrides) && "unexpected non-strided memref");
1056 (void)successStrides;
1057
1058 MemRefDescriptor memRefDescriptor(memRefDesc);
1059 Value base = memRefDescriptor.alignedPtr(rewriter, loc);
1060
1061 Value index;
1062 if (offset != 0) // Skip if offset is zero.
1063 index = offset == MemRefType::getDynamicStrideOrOffset()
1064 ? memRefDescriptor.offset(rewriter, loc)
1065 : createIndexConstant(rewriter, loc, offset);
1066
1067 for (int i = 0, e = indices.size(); i < e; ++i) {
1068 Value increment = indices[i];
1069 if (strides[i] != 1) { // Skip if stride is 1.
1070 Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
1071 ? memRefDescriptor.stride(rewriter, loc, i)
1072 : createIndexConstant(rewriter, loc, strides[i]);
1073 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
1074 }
1075 index =
1076 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
1077 }
1078
1079 LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType();
1080 return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
1081 : base;
1082 }
1083
/// Returns the pointer to the memref element at `indices`; a plain forwarding
/// synonym for `getStridedElementPtr`.
Value ConvertToLLVMPattern::getDataPtr(
    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
    ConversionPatternRewriter &rewriter) const {
  return getStridedElementPtr(loc, type, memRefDesc, indices, rewriter);
}
1089
1090 // Check if the MemRefType `type` is supported by the lowering. We currently
1091 // only support memrefs with identity maps.
isSupportedMemRefType(MemRefType type) const1092 bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
1093 if (!typeConverter->convertType(type.getElementType()))
1094 return false;
1095 return type.getAffineMaps().empty() ||
1096 llvm::all_of(type.getAffineMaps(),
1097 [](AffineMap map) { return map.isIdentity(); });
1098 }
1099
getElementPtrType(MemRefType type) const1100 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
1101 auto elementType = type.getElementType();
1102 auto structElementType = unwrap(typeConverter->convertType(elementType));
1103 return structElementType.getPointerTo(type.getMemorySpace());
1104 }
1105
/// Emits IR computing, for a memref with identity layout, one size Value per
/// dimension (into `sizes`), one stride Value per dimension (into `strides`,
/// row-major), and the total buffer size in bytes (into `sizeBytes`).
/// Dynamic dimensions consume entries of `dynamicSizes` in order.
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
    Location loc, MemRefType memRefType, ArrayRef<Value> dynamicSizes,
    ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
    SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
  assert(isSupportedMemRefType(memRefType) &&
         "layout maps must have been normalized away");

  sizes.reserve(memRefType.getRank());
  unsigned dynamicIndex = 0;
  for (int64_t size : memRefType.getShape()) {
    sizes.push_back(size == ShapedType::kDynamicSize
                        ? dynamicSizes[dynamicIndex++]
                        : createIndexConstant(rewriter, loc, size));
  }

  // Strides: iterate sizes in reverse order and multiply.
  // `stride` tracks the static product (degenerating to kDynamicSize once any
  // dynamic dimension is encountered); `runningStride` is the IR value.
  int64_t stride = 1;
  Value runningStride = createIndexConstant(rewriter, loc, 1);
  strides.resize(memRefType.getRank());
  for (auto i = memRefType.getRank(); i-- > 0;) {
    strides[i] = runningStride;

    int64_t size = memRefType.getShape()[i];
    // Zero-sized dimensions do not affect the running stride.
    if (size == 0)
      continue;
    bool useSizeAsStride = stride == 1;
    if (size == ShapedType::kDynamicSize)
      stride = ShapedType::kDynamicSize;
    if (stride != ShapedType::kDynamicSize)
      stride *= size;

    if (useSizeAsStride)
      runningStride = sizes[i];
    else if (stride == ShapedType::kDynamicSize)
      runningStride =
          rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
    else
      runningStride = createIndexConstant(rewriter, loc, stride);
  }

  // Buffer size in bytes, via the null-GEP + ptrtoint "sizeof" idiom so the
  // element size accounts for the LLVM data layout.
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}
1153
getSizeInBytes(Location loc,Type type,ConversionPatternRewriter & rewriter) const1154 Value ConvertToLLVMPattern::getSizeInBytes(
1155 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
1156 // Compute the size of an individual element. This emits the MLIR equivalent
1157 // of the following sizeof(...) implementation in LLVM IR:
1158 // %0 = getelementptr %elementType* null, %indexType 1
1159 // %1 = ptrtoint %elementType* %0 to %indexType
1160 // which is a common pattern of getting the size of a type in bytes.
1161 auto convertedPtrType =
1162 typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo();
1163 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
1164 auto gep = rewriter.create<LLVM::GEPOp>(
1165 loc, convertedPtrType,
1166 ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
1167 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
1168 }
1169
getNumElements(Location loc,ArrayRef<Value> shape,ConversionPatternRewriter & rewriter) const1170 Value ConvertToLLVMPattern::getNumElements(
1171 Location loc, ArrayRef<Value> shape,
1172 ConversionPatternRewriter &rewriter) const {
1173 // Compute the total number of memref elements.
1174 Value numElements =
1175 shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
1176 for (unsigned i = 1, e = shape.size(); i < e; ++i)
1177 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
1178 return numElements;
1179 }
1180
1181 /// Creates and populates the memref descriptor struct given all its fields.
createMemRefDescriptor(Location loc,MemRefType memRefType,Value allocatedPtr,Value alignedPtr,ArrayRef<Value> sizes,ArrayRef<Value> strides,ConversionPatternRewriter & rewriter) const1182 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
1183 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
1184 ArrayRef<Value> sizes, ArrayRef<Value> strides,
1185 ConversionPatternRewriter &rewriter) const {
1186 auto structType = typeConverter->convertType(memRefType);
1187 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
1188
1189 // Field 1: Allocated pointer, used for malloc/free.
1190 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
1191
1192 // Field 2: Actual aligned pointer to payload.
1193 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
1194
1195 // Field 3: Offset in aligned pointer.
1196 memRefDescriptor.setOffset(rewriter, loc,
1197 createIndexConstant(rewriter, loc, 0));
1198
1199 // Fields 4: Sizes.
1200 for (auto en : llvm::enumerate(sizes))
1201 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
1202
1203 // Field 5: Strides.
1204 for (auto en : llvm::enumerate(strides))
1205 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
1206
1207 return memRefDescriptor;
1208 }
1209
1210 /// Only retain those attributes that are not constructed by
1211 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
1212 /// attributes.
filterFuncAttributes(ArrayRef<NamedAttribute> attrs,bool filterArgAttrs,SmallVectorImpl<NamedAttribute> & result)1213 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
1214 bool filterArgAttrs,
1215 SmallVectorImpl<NamedAttribute> &result) {
1216 for (const auto &attr : attrs) {
1217 if (attr.first == SymbolTable::getSymbolAttrName() ||
1218 attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" ||
1219 (filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
1220 continue;
1221 result.push_back(attr);
1222 }
1223 }
1224
1225 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
1226 /// arguments instead of unpacked arguments. This function can be called from C
1227 /// by passing a pointer to a C struct corresponding to a memref descriptor.
1228 /// Internally, the auxiliary function unpacks the descriptor into individual
1229 /// components and forwards them to `newFuncOp`.
wrapForExternalCallers(OpBuilder & rewriter,Location loc,LLVMTypeConverter & typeConverter,FuncOp funcOp,LLVM::LLVMFuncOp newFuncOp)1230 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
1231 LLVMTypeConverter &typeConverter,
1232 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
1233 auto type = funcOp.getType();
1234 SmallVector<NamedAttribute, 4> attributes;
1235 filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
1236 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
1237 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
1238 typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
1239 attributes);
1240
1241 OpBuilder::InsertionGuard guard(rewriter);
1242 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
1243
1244 SmallVector<Value, 8> args;
1245 for (auto &en : llvm::enumerate(type.getInputs())) {
1246 Value arg = wrapperFuncOp.getArgument(en.index());
1247 if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
1248 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
1249 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
1250 continue;
1251 }
1252 if (en.value().isa<UnrankedMemRefType>()) {
1253 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
1254 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
1255 continue;
1256 }
1257
1258 args.push_back(wrapperFuncOp.getArgument(en.index()));
1259 }
1260 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
1261 rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
1262 }
1263
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. Creates a body for the (external)
/// `newFuncOp` that allocates a memref descriptor on stack, packs the
/// individual arguments into this descriptor and passes a pointer to it into
/// the auxiliary function. This auxiliary external function is now compatible
/// with functions defined in C using pointers to C structs corresponding to a
/// memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                 LLVMTypeConverter &typeConverter,
                                 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
  OpBuilder::InsertionGuard guard(builder);

  LLVM::LLVMType wrapperType =
      typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
  // This conversion can only fail if it could not convert one of the argument
  // types. But since it has been applied to a non-wrapper function before, it
  // should have failed earlier and not reach this point at all.
  assert(wrapperType && "unexpected type conversion failure");

  SmallVector<NamedAttribute, 4> attributes;
  filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);

  // Create the auxiliary function.
  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
      wrapperType, LLVM::Linkage::External, attributes);

  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());

  // Get a ValueRange containing arguments.
  FunctionType type = funcOp.getType();
  SmallVector<Value, 8> args;
  args.reserve(type.getNumInputs());
  ValueRange wrapperArgsRange(newFuncOp.getArguments());

  // Iterate over the inputs of the original function and pack values into
  // memref descriptors if the original type is a memref.
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Value arg;
    // Number of `newFuncOp` arguments consumed by this original argument
    // (memrefs expand to several unpacked values).
    int numToDrop = 1;
    auto memRefType = en.value().dyn_cast<MemRefType>();
    auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
    if (memRefType || unrankedMemRefType) {
      numToDrop = memRefType
                      ? MemRefDescriptor::getNumUnpackedValues(memRefType)
                      : UnrankedMemRefDescriptor::getNumUnpackedValues();
      Value packed =
          memRefType
              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
                                       wrapperArgsRange.take_front(numToDrop))
              : UnrankedMemRefDescriptor::pack(
                    builder, loc, typeConverter, unrankedMemRefType,
                    wrapperArgsRange.take_front(numToDrop));

      // Store the packed descriptor on the stack and pass its address to the
      // wrapper instead of the unpacked values.
      auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
      Value one = builder.create<LLVM::ConstantOp>(
          loc, typeConverter.convertType(builder.getIndexType()),
          builder.getIntegerAttr(builder.getIndexType(), 1));
      Value allocated =
          builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
      builder.create<LLVM::StoreOp>(loc, packed, allocated);
      arg = allocated;
    } else {
      arg = wrapperArgsRange[0];
    }

    args.push_back(arg);
    wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
  }
  assert(wrapperArgsRange.empty() && "did not map some of the arguments");

  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
  builder.create<LLVM::ReturnOp>(loc, call.getResults());
}
1338
1339 namespace {
1340
/// Common base for FuncOp conversion patterns: provides the machinery to turn
/// a standard FuncOp into an LLVM::LLVMFuncOp with a converted signature.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;

  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
  // to this legalization pattern. Returns a null op on signature or region
  // type-conversion failure.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    if (!llvmType)
      return nullptr;

    // Propagate argument attributes to all converted arguments obtained after
    // converting a given original argument.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true,
                         attributes);
    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
      auto attr = impl::getArgAttrDict(funcOp, i);
      if (!attr)
        continue;

      // One original argument may map to several converted arguments;
      // replicate its attribute dictionary on each of them.
      auto mapping = result.getInputMapping(i);
      assert(mapping.hasValue() && "unexpected deletion of function argument");

      SmallString<8> name;
      for (size_t j = 0; j < mapping->size; ++j) {
        impl::getArgAttrName(mapping->inputNo + j, name);
        attributes.push_back(rewriter.getNamedAttr(name, attr));
      }
    }

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
        attributes);
    // Move the body over and convert block argument types to match the new
    // signature; fail the whole conversion if that is not possible.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                           &result)))
      return nullptr;

    return newFuncOp;
  }
};
1393
1394 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
1395 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
1396 /// information.
1397 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
1398 struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion__anonf5dcde620f11::FuncOpConversion1399 FuncOpConversion(LLVMTypeConverter &converter)
1400 : FuncOpConversionBase(converter) {}
1401
1402 LogicalResult
matchAndRewrite__anonf5dcde620f11::FuncOpConversion1403 matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
1404 ConversionPatternRewriter &rewriter) const override {
1405 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
1406 if (!newFuncOp)
1407 return failure();
1408
1409 if (getTypeConverter()->getOptions().emitCWrappers ||
1410 funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
1411 if (newFuncOp.isExternal())
1412 wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
1413 funcOp, newFuncOp);
1414 else
1415 wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
1416 funcOp, newFuncOp);
1417 }
1418
1419 rewriter.eraseOp(funcOp);
1420 return success();
1421 }
1422 };
1423
/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
/// to the MemRef element type. This will impact the calling convention and ABI.
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
  using FuncOpConversionBase::FuncOpConversionBase;

  LogicalResult
  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Store the type of memref-typed arguments before the conversion so that we
    // can promote them to MemRef descriptor at the beginning of the function.
    SmallVector<Type, 8> oldArgTypes =
        llvm::to_vector<8>(funcOp.getType().getInputs());

    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
    if (!newFuncOp)
      return failure();
    // External (body-less) functions have no arguments to promote.
    if (newFuncOp.getBody().empty()) {
      rewriter.eraseOp(funcOp);
      return success();
    }

    // Promote bare pointers from memref arguments to memref descriptors at the
    // beginning of the function so that all the memrefs in the function have a
    // uniform representation.
    Block *entryBlock = &newFuncOp.getBody().front();
    auto blockArgs = entryBlock->getArguments();
    assert(blockArgs.size() == oldArgTypes.size() &&
           "The number of arguments and types doesn't match");

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(entryBlock);
    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
      BlockArgument arg = std::get<0>(it);
      Type argTy = std::get<1>(it);

      // Unranked memrefs are not supported in the bare pointer calling
      // convention. We should have bailed out before in the presence of
      // unranked memrefs.
      assert(!argTy.isa<UnrankedMemRefType>() &&
             "Unranked memref is not supported");
      // Non-memref arguments need no promotion.
      auto memrefTy = argTy.dyn_cast<MemRefType>();
      if (!memrefTy)
        continue;

      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
      // or unranked memref descriptor and replace placeholder with the last
      // instruction of the memref descriptor.
      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
      // MemRef descriptor instructions. We may want to have a utility in the
      // rewriter to properly handle this use case.
      Location loc = funcOp.getLoc();
      auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
      rewriter.replaceUsesOfBlockArgument(arg, placeholder);

      // The descriptor reads `arg` directly (created after the use-swap above),
      // so it keeps the raw pointer while all other uses see the descriptor.
      Value desc = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), memrefTy, arg);
      rewriter.replaceOp(placeholder, {desc});
    }

    rewriter.eraseOp(funcOp);
    return success();
  }
};
1487
//////////////// Support for Lowering operations on n-D vectors ////////////////
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
// 1-D LLVM vectors.
struct NDVectorTypeInfo {
  // LLVM array struct which encodes n-D vectors.
  LLVM::LLVMType llvmArrayTy;
  // LLVM vector type which encodes the inner 1-D vector type. Left null by
  // extractNDVectorTypeInfo when the converted type does not have the expected
  // array-of-vectors shape.
  LLVM::LLVMType llvmVectorTy;
  // Multiplicity of llvmArrayTy to llvmVectorTy, i.e. the size of each array
  // dimension wrapping the innermost 1-D vector.
  SmallVector<int64_t, 4> arraySizes;
};
1499 } // namespace
1500
1501 // For >1-D vector types, extracts the necessary information to iterate over all
1502 // 1-D subvectors in the underlying llrepresentation of the n-D vector
1503 // Iterates on the llvm array type until we hit a non-array type (which is
1504 // asserted to be an llvm vector type).
extractNDVectorTypeInfo(VectorType vectorType,LLVMTypeConverter & converter)1505 static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
1506 LLVMTypeConverter &converter) {
1507 assert(vectorType.getRank() > 1 && "expected >1D vector type");
1508 NDVectorTypeInfo info;
1509 info.llvmArrayTy =
1510 converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
1511 if (!info.llvmArrayTy)
1512 return info;
1513 info.arraySizes.reserve(vectorType.getRank() - 1);
1514 auto llvmTy = info.llvmArrayTy;
1515 while (llvmTy.isArrayTy()) {
1516 info.arraySizes.push_back(llvmTy.getArrayNumElements());
1517 llvmTy = llvmTy.getArrayElementType();
1518 }
1519 if (!llvmTy.isVectorTy())
1520 return info;
1521 info.llvmVectorTy = llvmTy;
1522 return info;
1523 }
1524
1525 // Express `linearIndex` in terms of coordinates of `basis`.
1526 // Returns the empty vector when linearIndex is out of the range [0, P] where
1527 // P is the product of all the basis coordinates.
1528 //
1529 // Prerequisites:
1530 // Basis is an array of nonnegative integers (signed type inherited from
1531 // vector shape type).
getCoordinates(ArrayRef<int64_t> basis,unsigned linearIndex)1532 static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
1533 unsigned linearIndex) {
1534 SmallVector<int64_t, 4> res;
1535 res.reserve(basis.size());
1536 for (unsigned basisElement : llvm::reverse(basis)) {
1537 res.push_back(linearIndex % basisElement);
1538 linearIndex = linearIndex / basisElement;
1539 }
1540 if (linearIndex > 0)
1541 return {};
1542 std::reverse(res.begin(), res.end());
1543 return res;
1544 }
1545
1546 // Iterate of linear index, convert to coords space and insert splatted 1-D
1547 // vector in each position.
1548 template <typename Lambda>
nDVectorIterate(const NDVectorTypeInfo & info,OpBuilder & builder,Lambda fun)1549 void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
1550 Lambda fun) {
1551 unsigned ub = 1;
1552 for (auto s : info.arraySizes)
1553 ub *= s;
1554 for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
1555 auto coords = getCoordinates(info.arraySizes, linearIndex);
1556 // Linear index is out of bounds, we are done.
1557 if (coords.empty())
1558 break;
1559 assert(coords.size() == info.arraySizes.size());
1560 auto position = builder.getI64ArrayAttr(coords);
1561 fun(position);
1562 }
1563 }
1564 ////////////// End Support for Lowering operations on n-D vectors //////////////
1565
1566 /// Replaces the given operation "op" with a new operation of type "targetOp"
1567 /// and given operands.
oneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)1568 LogicalResult LLVM::detail::oneToOneRewrite(
1569 Operation *op, StringRef targetOp, ValueRange operands,
1570 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
1571 unsigned numResults = op->getNumResults();
1572
1573 Type packedType;
1574 if (numResults != 0) {
1575 packedType = typeConverter.packFunctionResults(op->getResultTypes());
1576 if (!packedType)
1577 return failure();
1578 }
1579
1580 // Create the operation through state since we don't know its C++ type.
1581 OperationState state(op->getLoc(), targetOp);
1582 state.addTypes(packedType);
1583 state.addOperands(operands);
1584 state.addAttributes(op->getAttrs());
1585 Operation *newOp = rewriter.createOperation(state);
1586
1587 // If the operation produced 0 or 1 result, return them immediately.
1588 if (numResults == 0)
1589 return rewriter.eraseOp(op), success();
1590 if (numResults == 1)
1591 return rewriter.replaceOp(op, newOp->getResult(0)), success();
1592
1593 // Otherwise, it had been converted to an operation producing a structure.
1594 // Extract individual results from the structure and return them as list.
1595 SmallVector<Value, 4> results;
1596 results.reserve(numResults);
1597 for (unsigned i = 0; i < numResults; ++i) {
1598 auto type = typeConverter.convertType(op->getResult(i).getType());
1599 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
1600 op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
1601 }
1602 rewriter.replaceOp(op, results);
1603 return success();
1604 }
1605
handleMultidimensionalVectors(Operation * op,ValueRange operands,LLVMTypeConverter & typeConverter,std::function<Value (LLVM::LLVMType,ValueRange)> createOperand,ConversionPatternRewriter & rewriter)1606 static LogicalResult handleMultidimensionalVectors(
1607 Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
1608 std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
1609 ConversionPatternRewriter &rewriter) {
1610 auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
1611 if (!vectorType)
1612 return failure();
1613 auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
1614 auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
1615 auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
1616 if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
1617 return failure();
1618
1619 auto loc = op->getLoc();
1620 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
1621 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
1622 // For this unrolled `position` corresponding to the `linearIndex`^th
1623 // element, extract operand vectors
1624 SmallVector<Value, 4> extractedOperands;
1625 for (auto operand : operands)
1626 extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
1627 loc, llvmVectorTy, operand, position));
1628 Value newVal = createOperand(llvmVectorTy, extractedOperands);
1629 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
1630 position);
1631 });
1632 rewriter.replaceOp(op, desc);
1633 return success();
1634 }
1635
vectorOneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)1636 LogicalResult LLVM::detail::vectorOneToOneRewrite(
1637 Operation *op, StringRef targetOp, ValueRange operands,
1638 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
1639 assert(!operands.empty());
1640
1641 // Cannot convert ops if their operands are not of LLVM type.
1642 if (!llvm::all_of(operands.getTypes(),
1643 [](Type t) { return t.isa<LLVM::LLVMType>(); }))
1644 return failure();
1645
1646 auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
1647 if (!llvmArrayTy.isArrayTy())
1648 return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
1649
1650 auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
1651 ValueRange operands) {
1652 OperationState state(op->getLoc(), targetOp);
1653 state.addTypes(llvmVectorTy);
1654 state.addOperands(operands);
1655 state.addAttributes(op->getAttrs());
1656 return rewriter.createOperation(state)->getResult(0);
1657 };
1658
1659 return handleMultidimensionalVectors(op, operands, typeConverter, callback,
1660 rewriter);
1661 }
1662
1663 namespace {
// Straightforward lowerings.
// Each alias instantiates a generic 1:1 conversion pattern mapping a standard
// dialect op to its LLVM dialect counterpart:
//  - VectorConvertToLLVMPattern additionally unrolls operations on n-D vectors
//    into operations on 1-D LLVM vectors;
//  - OneToOneConvertToLLVMPattern performs the direct rewrite only.
using AbsFOpLowering = VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using CopySignOpLowering =
    VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using ExpOpLowering = VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<Exp2Op, LLVM::Exp2Op>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using Log10OpLowering = VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op>;
using Log2OpLowering = VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op>;
using LogOpLowering = VectorConvertToLLVMPattern<LogOp, LLVM::LogOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
using ShiftLeftOpLowering =
    OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
using SignedDivIOpLowering =
    VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp>;
using SignedRemIOpLowering =
    VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp>;
using SignedShiftRightOpLowering =
    OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
using SinOpLowering = VectorConvertToLLVMPattern<SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
using UnsignedDivIOpLowering =
    VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
using UnsignedRemIOpLowering =
    VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp>;
using UnsignedShiftRightOpLowering =
    OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
1705
/// Lower `std.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
  using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    AssertOp::Adaptor transformed(operands);

    // Insert the `abort` declaration if necessary. It is declared at module
    // scope as `void abort()` with external linkage.
    auto module = op->getParentOfType<ModuleOp>();
    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
    if (!abortFunc) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      auto abortFuncTy =
          LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false);
      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
                                                    "abort", abortFuncTy);
    }

    // Split block at `assert` operation. Everything after the assert moves to
    // the continuation block.
    Block *opBlock = rewriter.getInsertionBlock();
    auto opPosition = rewriter.getInsertionPoint();
    Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);

    // Generate IR to call `abort`. createBlock moves the insertion point into
    // the new block, so the call and unreachable land in `failureBlock`.
    Block *failureBlock = rewriter.createBlock(opBlock->getParent());
    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
    rewriter.create<LLVM::UnreachableOp>(loc);

    // Generate assertion test: branch to the continuation on success and to
    // the abort block on failure, replacing the assert op itself.
    rewriter.setInsertionPointToEnd(opBlock);
    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
        op, transformed.arg(), continuationBlock, failureBlock);

    return success();
  }
};
1749
1750 // Lowerings for operations on complex numbers.
1751
1752 struct CreateComplexOpLowering
1753 : public ConvertOpToLLVMPattern<CreateComplexOp> {
1754 using ConvertOpToLLVMPattern<CreateComplexOp>::ConvertOpToLLVMPattern;
1755
1756 LogicalResult
matchAndRewrite__anonf5dcde621311::CreateComplexOpLowering1757 matchAndRewrite(CreateComplexOp op, ArrayRef<Value> operands,
1758 ConversionPatternRewriter &rewriter) const override {
1759 auto complexOp = cast<CreateComplexOp>(op);
1760 CreateComplexOp::Adaptor transformed(operands);
1761
1762 // Pack real and imaginary part in a complex number struct.
1763 auto loc = op.getLoc();
1764 auto structType = typeConverter->convertType(complexOp.getType());
1765 auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
1766 complexStruct.setReal(rewriter, loc, transformed.real());
1767 complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
1768
1769 rewriter.replaceOp(op, {complexStruct});
1770 return success();
1771 }
1772 };
1773
1774 struct ReOpLowering : public ConvertOpToLLVMPattern<ReOp> {
1775 using ConvertOpToLLVMPattern<ReOp>::ConvertOpToLLVMPattern;
1776
1777 LogicalResult
matchAndRewrite__anonf5dcde621311::ReOpLowering1778 matchAndRewrite(ReOp op, ArrayRef<Value> operands,
1779 ConversionPatternRewriter &rewriter) const override {
1780 ReOp::Adaptor transformed(operands);
1781
1782 // Extract real part from the complex number struct.
1783 ComplexStructBuilder complexStruct(transformed.complex());
1784 Value real = complexStruct.real(rewriter, op.getLoc());
1785 rewriter.replaceOp(op, real);
1786
1787 return success();
1788 }
1789 };
1790
1791 struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
1792 using ConvertOpToLLVMPattern<ImOp>::ConvertOpToLLVMPattern;
1793
1794 LogicalResult
matchAndRewrite__anonf5dcde621311::ImOpLowering1795 matchAndRewrite(ImOp op, ArrayRef<Value> operands,
1796 ConversionPatternRewriter &rewriter) const override {
1797 ImOp::Adaptor transformed(operands);
1798
1799 // Extract imaginary part from the complex number struct.
1800 ComplexStructBuilder complexStruct(transformed.complex());
1801 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
1802 rewriter.replaceOp(op, imaginary);
1803
1804 return success();
1805 }
1806 };
1807
// Bundles the unpacked (real, imag) parts of both sides of a binary complex
// operation. std::complex<Value> is used purely as a convenient two-field
// holder via its real()/imag() accessors, not for arithmetic.
struct BinaryComplexOperands {
  std::complex<Value> lhs, rhs;
};
1811
1812 template <typename OpTy>
1813 BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)1814 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
1815 ConversionPatternRewriter &rewriter) {
1816 auto bop = cast<OpTy>(op);
1817 auto loc = bop.getLoc();
1818 typename OpTy::Adaptor transformed(operands);
1819
1820 // Extract real and imaginary values from operands.
1821 BinaryComplexOperands unpacked;
1822 ComplexStructBuilder lhs(transformed.lhs());
1823 unpacked.lhs.real(lhs.real(rewriter, loc));
1824 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
1825 ComplexStructBuilder rhs(transformed.rhs());
1826 unpacked.rhs.real(rhs.real(rewriter, loc));
1827 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
1828
1829 return unpacked;
1830 }
1831
1832 struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
1833 using ConvertOpToLLVMPattern<AddCFOp>::ConvertOpToLLVMPattern;
1834
1835 LogicalResult
matchAndRewrite__anonf5dcde621311::AddCFOpLowering1836 matchAndRewrite(AddCFOp op, ArrayRef<Value> operands,
1837 ConversionPatternRewriter &rewriter) const override {
1838 auto loc = op.getLoc();
1839 BinaryComplexOperands arg =
1840 unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
1841
1842 // Initialize complex number struct for result.
1843 auto structType = typeConverter->convertType(op.getType());
1844 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
1845
1846 // Emit IR to add complex numbers.
1847 Value real =
1848 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real());
1849 Value imag =
1850 rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag());
1851 result.setReal(rewriter, loc, real);
1852 result.setImaginary(rewriter, loc, imag);
1853
1854 rewriter.replaceOp(op, {result});
1855 return success();
1856 }
1857 };
1858
1859 struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
1860 using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;
1861
1862 LogicalResult
matchAndRewrite__anonf5dcde621311::SubCFOpLowering1863 matchAndRewrite(SubCFOp op, ArrayRef<Value> operands,
1864 ConversionPatternRewriter &rewriter) const override {
1865 auto loc = op.getLoc();
1866 BinaryComplexOperands arg =
1867 unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
1868
1869 // Initialize complex number struct for result.
1870 auto structType = typeConverter->convertType(op.getType());
1871 auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
1872
1873 // Emit IR to substract complex numbers.
1874 Value real =
1875 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real());
1876 Value imag =
1877 rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag());
1878 result.setReal(rewriter, loc, real);
1879 result.setImaginary(rewriter, loc, imag);
1880
1881 rewriter.replaceOp(op, {result});
1882 return success();
1883 }
1884 };
1885
/// Lowers `std.constant`. Function-reference constants become
/// `llvm.mlir.addressof`; all other constants are rewritten 1:1 to
/// `llvm.mlir.constant`. Non-flat symbol references are rejected.
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
  using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // If constant refers to a function, convert it to "addressof".
    if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
      auto type = typeConverter->convertType(op.getResult().getType())
                      .dyn_cast_or_null<LLVM::LLVMType>();
      if (!type)
        return rewriter.notifyMatchFailure(op, "failed to convert result type");

      // Drop the "value" attribute (the symbol is passed explicitly to
      // AddressOfOp) while forwarding all other attributes.
      MutableDictionaryAttr attrs(op.getAttrs());
      attrs.remove(rewriter.getIdentifier("value"));
      rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
          op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
          attrs.getAttrs());
      return success();
    }

    // Calling into other scopes (non-flat reference) is not supported in LLVM.
    if (op.getValue().isa<SymbolRefAttr>())
      return rewriter.notifyMatchFailure(
          op, "referring to a symbol outside of the current module");

    // All remaining constants map directly onto llvm.mlir.constant.
    return LLVM::detail::oneToOneRewrite(
        op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
        rewriter);
  }
};
1917
/// Lowering for AllocOp and AllocaOp.
/// Shared base for allocation-like lowerings: handles size/stride computation
/// and memref-descriptor construction, while delegating the actual buffer
/// allocation to the `allocateBuffer` hook implemented by subclasses.
struct AllocLikeOpLowering : public ConvertToLLVMPattern {
  using ConvertToLLVMPattern::createIndexConstant;
  using ConvertToLLVMPattern::getIndexType;
  using ConvertToLLVMPattern::getVoidPtrType;

  explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}

protected:
  // Returns 'input' aligned up to 'alignment'. Computes
  //   bumped = input + alignement - 1
  //   aligned = bumped - bumped % alignment
  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                             Value input, Value alignment) {
    Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
    Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
    Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
    Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
    return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
  }

  // Creates a call to an allocation function with params and casts the
  // resulting void pointer to ptrType. Declares the function at module scope
  // first if it is not already present.
  Value createAllocCall(Location loc, StringRef name, Type ptrType,
                        ArrayRef<Value> params, ModuleOp module,
                        ConversionPatternRewriter &rewriter) const {
    SmallVector<LLVM::LLVMType, 2> paramTypes;
    auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (!allocFuncOp) {
      // Derive the declaration's parameter types from the actual arguments.
      for (Value param : params)
        paramTypes.push_back(param.getType().cast<LLVM::LLVMType>());
      auto allocFuncType =
          LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes,
                                        /*isVarArg=*/false);
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
                                                      name, allocFuncType);
    }
    auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
    auto allocatedPtr = rewriter
                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
                                                  allocFuncSymbol, params)
                            .getResult(0);
    return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
  }

  /// Allocates the underlying buffer. Returns the allocated pointer and the
  /// aligned pointer.
  virtual std::tuple<Value, Value>
  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
                 Value sizeBytes, Operation *op) const = 0;

private:
  // The allocation-like op produces exactly one memref result.
  static MemRefType getMemRefResultType(Operation *op) {
    return op->getResult(0).getType().cast<MemRefType>();
  }

  // Only memref types supported by the descriptor lowering are matched.
  LogicalResult match(Operation *op) const override {
    MemRefType memRefType = getMemRefResultType(op);
    return success(isSupportedMemRefType(memRefType));
  }

  // An `alloc` is converted into a definition of a memref descriptor value and
  // a call to `malloc` to allocate the underlying data buffer. The memref
  // descriptor is of the LLVM structure type where:
  //   1. the first element is a pointer to the allocated (typed) data buffer,
  //   2. the second element is a pointer to the (typed) payload, aligned to the
  //      specified alignment,
  //   3. the remaining elements serve to store all the sizes and strides of the
  //      memref using LLVM-converted `index` type.
  //
  // Alignment is performed by allocating `alignment` more bytes than
  // requested and shifting the aligned pointer relative to the allocated
  // memory. Note: `alignment - <minimum malloc alignment>` would actually be
  // sufficient. If alignment is unspecified, the two pointers are equal.

  // An `alloca` is converted into a definition of a memref descriptor value and
  // an llvm.alloca to allocate the underlying data buffer.
  void rewrite(Operation *op, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = getMemRefResultType(op);
    auto loc = op->getLoc();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
                                   strides, sizeBytes);

    // Allocate the underlying buffer via the subclass-specific hook.
    Value allocatedPtr;
    Value alignedPtr;
    std::tie(allocatedPtr, alignedPtr) =
        this->allocateBuffer(rewriter, loc, sizeBytes, op);

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
  }
};
2026
2027 struct AllocOpLowering : public AllocLikeOpLowering {
AllocOpLowering__anonf5dcde621311::AllocOpLowering2028 AllocOpLowering(LLVMTypeConverter &converter)
2029 : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
2030
allocateBuffer__anonf5dcde621311::AllocOpLowering2031 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
2032 Location loc, Value sizeBytes,
2033 Operation *op) const override {
2034 // Heap allocations.
2035 AllocOp allocOp = cast<AllocOp>(op);
2036 MemRefType memRefType = allocOp.getType();
2037
2038 Value alignment;
2039 if (auto alignmentAttr = allocOp.alignment()) {
2040 alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
2041 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
2042 // In the case where no alignment is specified, we may want to override
2043 // `malloc's` behavior. `malloc` typically aligns at the size of the
2044 // biggest scalar on a target HW. For non-scalars, use the natural
2045 // alignment of the LLVM type given by the LLVM DataLayout.
2046 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
2047 }
2048
2049 if (alignment) {
2050 // Adjust the allocation size to consider alignment.
2051 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
2052 }
2053
2054 // Allocate the underlying buffer and store a pointer to it in the MemRef
2055 // descriptor.
2056 Type elementPtrType = this->getElementPtrType(memRefType);
2057 Value allocatedPtr =
2058 createAllocCall(loc, "malloc", elementPtrType, {sizeBytes},
2059 allocOp->getParentOfType<ModuleOp>(), rewriter);
2060
2061 Value alignedPtr = allocatedPtr;
2062 if (alignment) {
2063 auto intPtrType = getIntPtrType(memRefType.getMemorySpace());
2064 // Compute the aligned type pointer.
2065 Value allocatedInt =
2066 rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, allocatedPtr);
2067 Value alignmentInt =
2068 createAligned(rewriter, loc, allocatedInt, alignment);
2069 alignedPtr =
2070 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
2071 }
2072
2073 return std::make_tuple(allocatedPtr, alignedPtr);
2074 }
2075 };
2076
2077 struct AlignedAllocOpLowering : public AllocLikeOpLowering {
AlignedAllocOpLowering__anonf5dcde621311::AlignedAllocOpLowering2078 AlignedAllocOpLowering(LLVMTypeConverter &converter)
2079 : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
2080
2081 /// Returns the memref's element size in bytes.
2082 // TODO: there are other places where this is used. Expose publicly?
getMemRefEltSizeInBytes__anonf5dcde621311::AlignedAllocOpLowering2083 static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
2084 auto elementType = memRefType.getElementType();
2085
2086 unsigned sizeInBits;
2087 if (elementType.isIntOrFloat()) {
2088 sizeInBits = elementType.getIntOrFloatBitWidth();
2089 } else {
2090 auto vectorType = elementType.cast<VectorType>();
2091 sizeInBits =
2092 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
2093 }
2094 return llvm::divideCeil(sizeInBits, 8);
2095 }
2096
2097 /// Returns true if the memref size in bytes is known to be a multiple of
2098 /// factor.
isMemRefSizeMultipleOf__anonf5dcde621311::AlignedAllocOpLowering2099 static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
2100 uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
2101 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
2102 if (type.isDynamic(type.getDimSize(i)))
2103 continue;
2104 sizeDivisor = sizeDivisor * type.getDimSize(i);
2105 }
2106 return sizeDivisor % factor == 0;
2107 }
2108
2109 /// Returns the alignment to be used for the allocation call itself.
2110 /// aligned_alloc requires the allocation size to be a power of two, and the
2111 /// allocation size to be a multiple of alignment,
getAllocationAlignment__anonf5dcde621311::AlignedAllocOpLowering2112 int64_t getAllocationAlignment(AllocOp allocOp) const {
2113 if (Optional<uint64_t> alignment = allocOp.alignment())
2114 return *alignment;
2115
2116 // Whenever we don't have alignment set, we will use an alignment
2117 // consistent with the element type; since the allocation size has to be a
2118 // power of two, we will bump to the next power of two if it already isn't.
2119 auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
2120 return std::max(kMinAlignedAllocAlignment,
2121 llvm::PowerOf2Ceil(eltSizeBytes));
2122 }
2123
allocateBuffer__anonf5dcde621311::AlignedAllocOpLowering2124 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
2125 Location loc, Value sizeBytes,
2126 Operation *op) const override {
2127 // Heap allocations.
2128 AllocOp allocOp = cast<AllocOp>(op);
2129 MemRefType memRefType = allocOp.getType();
2130 int64_t alignment = getAllocationAlignment(allocOp);
2131 Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
2132
2133 // aligned_alloc requires size to be a multiple of alignment; we will pad
2134 // the size to the next multiple if necessary.
2135 if (!isMemRefSizeMultipleOf(memRefType, alignment))
2136 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
2137
2138 Type elementPtrType = this->getElementPtrType(memRefType);
2139 Value allocatedPtr = createAllocCall(
2140 loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes},
2141 allocOp->getParentOfType<ModuleOp>(), rewriter);
2142
2143 return std::make_tuple(allocatedPtr, allocatedPtr);
2144 }
2145
2146 /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
2147 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
2148 };
2149
// Out-of-line definition of the static member, required until C++17 where
// `static constexpr` data members become implicitly inline.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
2152
/// Lowering for std.alloca: the buffer is stack-allocated via llvm.alloca,
/// so no deallocation call is ever needed.
struct AllocaOpLowering : public AllocLikeOpLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {}

  /// Allocates the underlying buffer with an llvm.alloca. Since alloca
  /// yields a pointer to the element type directly and accepts an alignment
  /// itself, the allocated and aligned pointers of the returned pair are the
  /// same value.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    // Forward the optional alignment attribute; 0 requests the default
    // alignment.
    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};
2176
2177 /// Copies the shaped descriptor part to (if `toDynamic` is set) or from
2178 /// (otherwise) the dynamically allocated memory for any operands that were
2179 /// unranked descriptors originally.
copyUnrankedDescriptors(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,TypeRange origTypes,SmallVectorImpl<Value> & operands,bool toDynamic)2180 static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
2181 LLVMTypeConverter &typeConverter,
2182 TypeRange origTypes,
2183 SmallVectorImpl<Value> &operands,
2184 bool toDynamic) {
2185 assert(origTypes.size() == operands.size() &&
2186 "expected as may original types as operands");
2187
2188 // Find operands of unranked memref type and store them.
2189 SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
2190 for (unsigned i = 0, e = operands.size(); i < e; ++i)
2191 if (origTypes[i].isa<UnrankedMemRefType>())
2192 unrankedMemrefs.emplace_back(operands[i]);
2193
2194 if (unrankedMemrefs.empty())
2195 return success();
2196
2197 // Compute allocation sizes.
2198 SmallVector<Value, 4> sizes;
2199 UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
2200 unrankedMemrefs, sizes);
2201
2202 // Get frequently used types.
2203 MLIRContext *context = builder.getContext();
2204 auto voidType = LLVM::LLVMType::getVoidTy(context);
2205 auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
2206 auto i1Type = LLVM::LLVMType::getInt1Ty(context);
2207 LLVM::LLVMType indexType = typeConverter.getIndexType();
2208
2209 // Find the malloc and free, or declare them if necessary.
2210 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
2211 auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
2212 if (!mallocFunc && toDynamic) {
2213 OpBuilder::InsertionGuard guard(builder);
2214 builder.setInsertionPointToStart(module.getBody());
2215 mallocFunc = builder.create<LLVM::LLVMFuncOp>(
2216 builder.getUnknownLoc(), "malloc",
2217 LLVM::LLVMType::getFunctionTy(
2218 voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false));
2219 }
2220 auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free");
2221 if (!freeFunc && !toDynamic) {
2222 OpBuilder::InsertionGuard guard(builder);
2223 builder.setInsertionPointToStart(module.getBody());
2224 freeFunc = builder.create<LLVM::LLVMFuncOp>(
2225 builder.getUnknownLoc(), "free",
2226 LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType),
2227 /*isVarArg=*/false));
2228 }
2229
2230 // Initialize shared constants.
2231 Value zero =
2232 builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
2233
2234 unsigned unrankedMemrefPos = 0;
2235 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
2236 Type type = origTypes[i];
2237 if (!type.isa<UnrankedMemRefType>())
2238 continue;
2239 Value allocationSize = sizes[unrankedMemrefPos++];
2240 UnrankedMemRefDescriptor desc(operands[i]);
2241
2242 // Allocate memory, copy, and free the source if necessary.
2243 Value memory =
2244 toDynamic
2245 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
2246 .getResult(0)
2247 : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
2248 /*alignment=*/0);
2249
2250 Value source = desc.memRefDescPtr(builder, loc);
2251 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
2252 if (!toDynamic)
2253 builder.create<LLVM::CallOp>(loc, freeFunc, source);
2254
2255 // Create a new descriptor. The same descriptor can be returned multiple
2256 // times, attempting to modify its pointer can lead to memory leaks
2257 // (allocated twice and overwritten) or double frees (the caller does not
2258 // know if the descriptor points to the same memory).
2259 Type descriptorType = typeConverter.convertType(type);
2260 if (!descriptorType)
2261 return failure();
2262 auto updatedDesc =
2263 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
2264 Value rank = desc.rank(builder, loc);
2265 updatedDesc.setRank(builder, loc, rank);
2266 updatedDesc.setMemRefDescPtr(builder, loc, memory);
2267
2268 operands[i] = updatedDesc;
2269 }
2270
2271 return success();
2272 }
2273
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
// passes the pointer to the MemRef across function boundaries.
template <typename CallOpType>
struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
  using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
  using Super = CallOpInterfaceLowering<CallOpType>;
  using Base = ConvertOpToLLVMPattern<CallOpType>;

  LogicalResult
  matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename CallOpType::Adaptor transformed(operands);

    // Pack the result types into a struct: LLVM functions return at most one
    // value, so multiple results are aggregated into a single struct.
    Type packedResult = nullptr;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

    if (numResults != 0) {
      // Bail out if any result type fails to convert.
      if (!(packedResult =
                this->getTypeConverter()->packFunctionResults(resultTypes)))
        return failure();
    }

    // Promote operands according to the calling convention in use (e.g.
    // passing ranked memref descriptors through the stack).
    auto promoted = this->getTypeConverter()->promoteOperands(
        callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
        rewriter);
    auto newOp = rewriter.create<LLVM::CallOp>(
        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
        promoted, callOp.getAttrs());

    SmallVector<Value, 4> results;
    if (numResults < 2) {
      // If < 2 results, packing did not do anything and we can just return.
      results.append(newOp.result_begin(), newOp.result_end());
    } else {
      // Otherwise, it had been converted to an operation producing a structure.
      // Extract individual results from the structure and return them as list.
      results.reserve(numResults);
      for (unsigned i = 0; i < numResults; ++i) {
        auto type =
            this->typeConverter->convertType(callOp.getResult(i).getType());
        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
            callOp.getLoc(), type, newOp->getResult(0),
            rewriter.getI64ArrayAttr(i)));
      }
    }

    if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
      // For the bare-ptr calling convention, promote memref results to
      // descriptors.
      assert(results.size() == resultTypes.size() &&
             "The number of arguments and types doesn't match");
      this->getTypeConverter()->promoteBarePtrsToDescriptors(
          rewriter, callOp.getLoc(), resultTypes, results);
    } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
                                              *this->getTypeConverter(),
                                              resultTypes, results,
                                              /*toDynamic=*/false))) {
      // Unranked descriptors returned by the callee must be copied back from
      // dynamically allocated memory.
      return failure();
    }

    rewriter.replaceOp(callOp, results);
    return success();
  }
};
2340
/// Lowers std.call via the shared call-lowering logic above.
struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
  using Super::Super;
};
2344
/// Lowers std.call_indirect via the shared call-lowering logic above.
struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
  using Super::Super;
};
2348
2349 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
2350 // The memref descriptor being an SSA value, there is no need to clean it up
2351 // in any way.
2352 struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
2353 using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
2354
DeallocOpLowering__anonf5dcde621311::DeallocOpLowering2355 explicit DeallocOpLowering(LLVMTypeConverter &converter)
2356 : ConvertOpToLLVMPattern<DeallocOp>(converter) {}
2357
2358 LogicalResult
matchAndRewrite__anonf5dcde621311::DeallocOpLowering2359 matchAndRewrite(DeallocOp op, ArrayRef<Value> operands,
2360 ConversionPatternRewriter &rewriter) const override {
2361 assert(operands.size() == 1 && "dealloc takes one operand");
2362 DeallocOp::Adaptor transformed(operands);
2363
2364 // Insert the `free` declaration if it is not already present.
2365 auto freeFunc =
2366 op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
2367 if (!freeFunc) {
2368 OpBuilder::InsertionGuard guard(rewriter);
2369 rewriter.setInsertionPointToStart(
2370 op->getParentOfType<ModuleOp>().getBody());
2371 freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
2372 rewriter.getUnknownLoc(), "free",
2373 LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
2374 /*isVarArg=*/false));
2375 }
2376
2377 MemRefDescriptor memref(transformed.memref());
2378 Value casted = rewriter.create<LLVM::BitcastOp>(
2379 op.getLoc(), getVoidPtrType(),
2380 memref.allocatedPtr(rewriter, op.getLoc()));
2381 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
2382 op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
2383 return success();
2384 }
2385 };
2386
2387 /// Returns the LLVM type of the global variable given the memref type `type`.
2388 static LLVM::LLVMType
convertGlobalMemrefTypeToLLVM(MemRefType type,LLVMTypeConverter & typeConverter)2389 convertGlobalMemrefTypeToLLVM(MemRefType type,
2390 LLVMTypeConverter &typeConverter) {
2391 // LLVM type for a global memref will be a multi-dimension array. For
2392 // declarations or uninitialized global memrefs, we can potentially flatten
2393 // this to a 1D array. However, for global_memref's with an initial value,
2394 // we do not intend to flatten the ElementsAttribute when going from std ->
2395 // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
2396 LLVM::LLVMType elementType =
2397 unwrap(typeConverter.convertType(type.getElementType()));
2398 LLVM::LLVMType arrayTy = elementType;
2399 // Shape has the outermost dim at index 0, so need to walk it backwards
2400 for (int64_t dim : llvm::reverse(type.getShape()))
2401 arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
2402 return arrayTy;
2403 }
2404
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
  using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(GlobalMemrefOp global, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type().cast<MemRefType>();
    // Bail out on memref types this lowering does not support.
    if (!isSupportedMemRefType(type))
      return failure();

    LLVM::LLVMType arrayTy =
        convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    // Public globals get external linkage; all others become private.
    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    // External declarations and explicitly-uninitialized globals carry no
    // initializer.
    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getValue({});
    }

    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, type.getMemorySpace());
    return success();
  }
};
2439
2440 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
2441 /// the first element stashed into the descriptor. This reuses
2442 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
2443 struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
GetGlobalMemrefOpLowering__anonf5dcde621311::GetGlobalMemrefOpLowering2444 GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
2445 : AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {}
2446
2447 /// Buffer "allocation" for get_global_memref op is getting the address of
2448 /// the global variable referenced.
allocateBuffer__anonf5dcde621311::GetGlobalMemrefOpLowering2449 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
2450 Location loc, Value sizeBytes,
2451 Operation *op) const override {
2452 auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
2453 MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
2454 unsigned memSpace = type.getMemorySpace();
2455
2456 LLVM::LLVMType arrayTy =
2457 convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
2458 auto addressOf = rewriter.create<LLVM::AddressOfOp>(
2459 loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
2460
2461 // Get the address of the first element in the array by creating a GEP with
2462 // the address of the GV as the base, and (rank + 1) number of 0 indices.
2463 LLVM::LLVMType elementType =
2464 unwrap(typeConverter->convertType(type.getElementType()));
2465 LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
2466
2467 SmallVector<Value, 4> operands = {addressOf};
2468 operands.insert(operands.end(), type.getRank() + 1,
2469 createIndexConstant(rewriter, loc, 0));
2470 auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
2471
2472 // We do not expect the memref obtained using `get_global_memref` to be
2473 // ever deallocated. Set the allocated pointer to be known bad value to
2474 // help debug if that ever happens.
2475 auto intPtrType = getIntPtrType(memSpace);
2476 Value deadBeefConst =
2477 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
2478 auto deadBeefPtr =
2479 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
2480
2481 // Both allocated and aligned pointers are same. We could potentially stash
2482 // a nullptr for the allocated pointer since we do not expect any dealloc.
2483 return std::make_tuple(deadBeefPtr, gep);
2484 }
2485 };
2486
// A `rsqrt` is converted into `1 / sqrt`.
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    RsqrtOp::Adaptor transformed(operands);
    auto operandType =
        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();

    // Only operands already converted to LLVM dialect types are handled.
    if (!operandType)
      return failure();

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);

    // Scalar and 1-D vector case: emit `fdiv 1.0, sqrt(x)` directly.
    if (!operandType.isArrayTy()) {
      LLVM::ConstantOp one;
      if (operandType.isVectorTy()) {
        // Splat the 1.0 constant across all vector lanes.
        one = rewriter.create<LLVM::ConstantOp>(
            loc, operandType,
            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
      } else {
        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
      }
      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
      return success();
    }

    // Multi-dimensional vectors lower to LLVM arrays of 1-D vectors; they are
    // processed one innermost vector at a time below.
    auto vectorType = resultType.dyn_cast<VectorType>();
    if (!vectorType)
      return failure();

    return handleMultidimensionalVectors(
        op.getOperation(), operands, *getTypeConverter(),
        [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
          // Same `1 / sqrt(x)` expansion, applied per 1-D vector.
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
                                    floatType),
              floatOne);
          auto one =
              rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
          auto sqrt =
              rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
          return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
        },
        rewriter);
  }
};
2540
/// Lowers memref_cast. A ranked-to-ranked cast keeps the source descriptor
/// unchanged; ranked-to-unranked wraps the descriptor into an unranked one;
/// unranked-to-ranked loads the ranked descriptor back. Unranked-to-unranked
/// casts are rejected in `match`.
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
  using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(MemRefCastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // MemRefCastOp reduce to bitcast in the ranked MemRef case and can be used
    // for type erasure. For now they must preserve underlying element type and
    // require source and result type to have the same rank. Therefore, perform
    // a sanity check that the underlying structs are the same. Once op
    // semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(MemRefCastOp memRefCastOp, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    MemRefCastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {transformed.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type
      // Set the rank in the destination from the memref type
      // Allocate space on the stack and copy the src memref descriptor
      // Set the ptr in the destination to the stack space
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked the type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
                  ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      // `match` rejected the unranked-to-unranked case, so it cannot occur.
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
2630
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor. The offset is only produced when the caller passes a
/// non-null `offset` pointer.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    // Ranked case: the converted operand is a memref descriptor struct whose
    // fields can be read directly.
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // Unranked case: the descriptor stores only a rank and an opaque pointer to
  // the underlying ranked descriptor; compute the element pointer-to-pointer
  // type needed to read through that opaque pointer.
  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpace();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  LLVM::LLVMType llvmElementType =
      unwrap(typeConverter.convertType(elementType));
  LLVM::LLVMType elementPtrPtrType =
      llvmElementType.getPointerTo(memorySpace).getPointerTo();

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}
2673
/// Lowers memref_reinterpret_cast by building a fresh memref descriptor that
/// reuses the source's allocated and aligned pointers but carries the target
/// offset, sizes and strides (taken from op attributes when static, from the
/// op's operands when dynamic).
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
  using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefReinterpretCastOp::Adaptor adaptor(operands,
                                             castOp->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  /// Builds the result descriptor into `descriptor`. Fails when the target
  /// memref type does not convert to an LLVM struct type.
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, MemRefReinterpretCastOp castOp,
                                  MemRefReinterpretCastOp::Adaptor adaptor,
                                  Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers (taken from the source, which may be
    // ranked or unranked).
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides. Dynamic values are consumed from the adaptor's
    // operand lists in order; static ones come from the op's attributes.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
2742
2743 struct MemRefReshapeOpLowering
2744 : public ConvertOpToLLVMPattern<MemRefReshapeOp> {
2745 using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
2746
2747 LogicalResult
matchAndRewrite__anonf5dcde621311::MemRefReshapeOpLowering2748 matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef<Value> operands,
2749 ConversionPatternRewriter &rewriter) const override {
2750 auto *op = reshapeOp.getOperation();
2751 MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
2752 Type srcType = reshapeOp.source().getType();
2753
2754 Value descriptor;
2755 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
2756 adaptor, &descriptor)))
2757 return failure();
2758 rewriter.replaceOp(op, {descriptor});
2759 return success();
2760 }
2761
2762 private:
2763 LogicalResult
convertSourceMemRefToDescriptor__anonf5dcde621311::MemRefReshapeOpLowering2764 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
2765 Type srcType, MemRefReshapeOp reshapeOp,
2766 MemRefReshapeOp::Adaptor adaptor,
2767 Value *descriptor) const {
2768 // Conversion for statically-known shape args is performed via
2769 // `memref_reinterpret_cast`.
2770 auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
2771 if (shapeMemRefType.hasStaticShape())
2772 return failure();
2773
2774 // The shape is a rank-1 tensor with unknown length.
2775 Location loc = reshapeOp.getLoc();
2776 MemRefDescriptor shapeDesc(adaptor.shape());
2777 Value resultRank = shapeDesc.size(rewriter, loc, 0);
2778
2779 // Extract address space and element type.
2780 auto targetType =
2781 reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
2782 unsigned addressSpace = targetType.getMemorySpace();
2783 Type elementType = targetType.getElementType();
2784
2785 // Create the unranked memref descriptor that holds the ranked one. The
2786 // inner descriptor is allocated on stack.
2787 auto targetDesc = UnrankedMemRefDescriptor::undef(
2788 rewriter, loc, unwrap(typeConverter->convertType(targetType)));
2789 targetDesc.setRank(rewriter, loc, resultRank);
2790 SmallVector<Value, 4> sizes;
2791 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
2792 targetDesc, sizes);
2793 Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
2794 loc, getVoidPtrType(), sizes.front(), llvm::None);
2795 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
2796
2797 // Extract pointers and offset from the source memref.
2798 Value allocatedPtr, alignedPtr, offset;
2799 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
2800 reshapeOp.source(), adaptor.source(),
2801 &allocatedPtr, &alignedPtr, &offset);
2802
2803 // Set pointers and offset.
2804 LLVM::LLVMType llvmElementType =
2805 unwrap(typeConverter->convertType(elementType));
2806 LLVM::LLVMType elementPtrPtrType =
2807 llvmElementType.getPointerTo(addressSpace).getPointerTo();
2808 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
2809 elementPtrPtrType, allocatedPtr);
2810 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
2811 underlyingDescPtr,
2812 elementPtrPtrType, alignedPtr);
2813 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
2814 underlyingDescPtr, elementPtrPtrType,
2815 offset);
2816
2817 // Use the offset pointer as base for further addressing. Copy over the new
2818 // shape and compute strides. For this, we create a loop from rank-1 to 0.
2819 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
2820 rewriter, loc, *getTypeConverter(), underlyingDescPtr,
2821 elementPtrPtrType);
2822 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
2823 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
2824 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
2825 Value oneIndex = createIndexConstant(rewriter, loc, 1);
2826 Value resultRankMinusOne =
2827 rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
2828
2829 Block *initBlock = rewriter.getInsertionBlock();
2830 LLVM::LLVMType indexType = getTypeConverter()->getIndexType();
2831 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
2832
2833 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
2834 {indexType, indexType});
2835
2836 // Iterate over the remaining ops in initBlock and move them to condBlock.
2837 BlockAndValueMapping map;
2838 for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) {
2839 rewriter.clone(*it, map);
2840 rewriter.eraseOp(&*it);
2841 }
2842
2843 rewriter.setInsertionPointToEnd(initBlock);
2844 rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
2845 condBlock);
2846 rewriter.setInsertionPointToStart(condBlock);
2847 Value indexArg = condBlock->getArgument(0);
2848 Value strideArg = condBlock->getArgument(1);
2849
2850 Value zeroIndex = createIndexConstant(rewriter, loc, 0);
2851 Value pred = rewriter.create<LLVM::ICmpOp>(
2852 loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
2853 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
2854
2855 Block *bodyBlock =
2856 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
2857 rewriter.setInsertionPointToStart(bodyBlock);
2858
2859 // Copy size from shape to descriptor.
2860 LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo();
2861 Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
2862 loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
2863 Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
2864 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
2865 targetSizesBase, indexArg, size);
2866
2867 // Write stride value and compute next one.
2868 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
2869 targetStridesBase, indexArg, strideArg);
2870 Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
2871
2872 // Decrement loop counter and branch back.
2873 Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
2874 rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
2875 condBlock);
2876
2877 Block *remainder =
2878 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
2879
2880 // Hook up the cond exit to the remainder.
2881 rewriter.setInsertionPointToEnd(condBlock);
2882 rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
2883 llvm::None);
2884
2885 // Reset position to beginning of new remainder block.
2886 rewriter.setInsertionPointToStart(remainder);
2887
2888 *descriptor = targetDesc;
2889 return success();
2890 }
2891 };
2892
2893 struct DialectCastOpLowering
2894 : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
2895 using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
2896
2897 LogicalResult
matchAndRewrite__anonf5dcde621311::DialectCastOpLowering2898 matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef<Value> operands,
2899 ConversionPatternRewriter &rewriter) const override {
2900 LLVM::DialectCastOp::Adaptor transformed(operands);
2901 if (transformed.in().getType() !=
2902 typeConverter->convertType(castOp.getType())) {
2903 return failure();
2904 }
2905 rewriter.replaceOp(castOp, transformed.in());
2906 return success();
2907 }
2908 };
2909
2910 // A `dim` is converted to a constant for static sizes and to an access to the
2911 // size stored in the memref descriptor for dynamic sizes.
2912 struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
2913 using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
2914
2915 LogicalResult
matchAndRewrite__anonf5dcde621311::DimOpLowering2916 matchAndRewrite(DimOp dimOp, ArrayRef<Value> operands,
2917 ConversionPatternRewriter &rewriter) const override {
2918 Type operandType = dimOp.memrefOrTensor().getType();
2919 if (operandType.isa<UnrankedMemRefType>()) {
2920 rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
2921 operandType, dimOp, operands, rewriter)});
2922
2923 return success();
2924 }
2925 if (operandType.isa<MemRefType>()) {
2926 rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
2927 operandType, dimOp, operands, rewriter)});
2928 return success();
2929 }
2930 return failure();
2931 }
2932
2933 private:
extractSizeOfUnrankedMemRef__anonf5dcde621311::DimOpLowering2934 Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
2935 ArrayRef<Value> operands,
2936 ConversionPatternRewriter &rewriter) const {
2937 Location loc = dimOp.getLoc();
2938 DimOp::Adaptor transformed(operands);
2939
2940 auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
2941 auto scalarMemRefType =
2942 MemRefType::get({}, unrankedMemRefType.getElementType());
2943 unsigned addressSpace = unrankedMemRefType.getMemorySpace();
2944
2945 // Extract pointer to the underlying ranked descriptor and bitcast it to a
2946 // memref<element_type> descriptor pointer to minimize the number of GEP
2947 // operations.
2948 UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
2949 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
2950 Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
2951 loc,
2952 typeConverter->convertType(scalarMemRefType)
2953 .cast<LLVM::LLVMType>()
2954 .getPointerTo(addressSpace),
2955 underlyingRankedDesc);
2956
2957 // Get pointer to offset field of memref<element_type> descriptor.
2958 Type indexPtrTy =
2959 getTypeConverter()->getIndexType().getPointerTo(addressSpace);
2960 Value two = rewriter.create<LLVM::ConstantOp>(
2961 loc, typeConverter->convertType(rewriter.getI32Type()),
2962 rewriter.getI32IntegerAttr(2));
2963 Value offsetPtr = rewriter.create<LLVM::GEPOp>(
2964 loc, indexPtrTy, scalarMemRefDescPtr,
2965 ValueRange({createIndexConstant(rewriter, loc, 0), two}));
2966
2967 // The size value that we have to extract can be obtained using GEPop with
2968 // `dimOp.index() + 1` index argument.
2969 Value idxPlusOne = rewriter.create<LLVM::AddOp>(
2970 loc, createIndexConstant(rewriter, loc, 1), transformed.index());
2971 Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
2972 ValueRange({idxPlusOne}));
2973 return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
2974 }
2975
extractSizeOfRankedMemRef__anonf5dcde621311::DimOpLowering2976 Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
2977 ArrayRef<Value> operands,
2978 ConversionPatternRewriter &rewriter) const {
2979 Location loc = dimOp.getLoc();
2980 DimOp::Adaptor transformed(operands);
2981 // Take advantage if index is constant.
2982 MemRefType memRefType = operandType.cast<MemRefType>();
2983 if (Optional<int64_t> index = dimOp.getConstantIndex()) {
2984 int64_t i = index.getValue();
2985 if (memRefType.isDynamicDim(i)) {
2986 // extract dynamic size from the memref descriptor.
2987 MemRefDescriptor descriptor(transformed.memrefOrTensor());
2988 return descriptor.size(rewriter, loc, i);
2989 }
2990 // Use constant for static size.
2991 int64_t dimSize = memRefType.getDimSize(i);
2992 return createIndexConstant(rewriter, loc, dimSize);
2993 }
2994 Value index = dimOp.index();
2995 int64_t rank = memRefType.getRank();
2996 MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
2997 return memrefDescriptor.size(rewriter, loc, index, rank);
2998 }
2999 };
3000
3001 struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
3002 using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
3003
3004 LogicalResult
matchAndRewrite__anonf5dcde621311::RankOpLowering3005 matchAndRewrite(RankOp op, ArrayRef<Value> operands,
3006 ConversionPatternRewriter &rewriter) const override {
3007 Location loc = op.getLoc();
3008 Type operandType = op.memrefOrTensor().getType();
3009 if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
3010 UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
3011 rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
3012 return success();
3013 }
3014 if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
3015 rewriter.replaceOp(
3016 op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
3017 return success();
3018 }
3019 return failure();
3020 }
3021 };
3022
3023 // Common base for load and store operations on MemRefs. Restricts the match
3024 // to supported MemRef types. Provides functionality to emit code accessing a
3025 // specific element of the underlying data buffer.
3026 template <typename Derived>
3027 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
3028 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
3029 using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
3030 using Base = LoadStoreOpLowering<Derived>;
3031
match__anonf5dcde621311::LoadStoreOpLowering3032 LogicalResult match(Derived op) const override {
3033 MemRefType type = op.getMemRefType();
3034 return isSupportedMemRefType(type) ? success() : failure();
3035 }
3036 };
3037
3038 // Load operation is lowered to obtaining a pointer to the indexed element
3039 // and loading it.
3040 struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
3041 using Base::Base;
3042
3043 LogicalResult
matchAndRewrite__anonf5dcde621311::LoadOpLowering3044 matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
3045 ConversionPatternRewriter &rewriter) const override {
3046 LoadOp::Adaptor transformed(operands);
3047 auto type = loadOp.getMemRefType();
3048
3049 Value dataPtr =
3050 getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
3051 transformed.indices(), rewriter);
3052 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
3053 return success();
3054 }
3055 };
3056
3057 // Store operation is lowered to obtaining a pointer to the indexed element,
3058 // and storing the given value to it.
3059 struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
3060 using Base::Base;
3061
3062 LogicalResult
matchAndRewrite__anonf5dcde621311::StoreOpLowering3063 matchAndRewrite(StoreOp op, ArrayRef<Value> operands,
3064 ConversionPatternRewriter &rewriter) const override {
3065 auto type = op.getMemRefType();
3066 StoreOp::Adaptor transformed(operands);
3067
3068 Value dataPtr =
3069 getStridedElementPtr(op.getLoc(), type, transformed.memref(),
3070 transformed.indices(), rewriter);
3071 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
3072 dataPtr);
3073 return success();
3074 }
3075 };
3076
3077 // The prefetch operation is lowered in a way similar to the load operation
3078 // except that the llvm.prefetch operation is used for replacement.
3079 struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
3080 using Base::Base;
3081
3082 LogicalResult
matchAndRewrite__anonf5dcde621311::PrefetchOpLowering3083 matchAndRewrite(PrefetchOp prefetchOp, ArrayRef<Value> operands,
3084 ConversionPatternRewriter &rewriter) const override {
3085 PrefetchOp::Adaptor transformed(operands);
3086 auto type = prefetchOp.getMemRefType();
3087 auto loc = prefetchOp.getLoc();
3088
3089 Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
3090 transformed.indices(), rewriter);
3091
3092 // Replace with llvm.prefetch.
3093 auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
3094 auto isWrite = rewriter.create<LLVM::ConstantOp>(
3095 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
3096 auto localityHint = rewriter.create<LLVM::ConstantOp>(
3097 loc, llvmI32Type,
3098 rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
3099 auto isData = rewriter.create<LLVM::ConstantOp>(
3100 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
3101
3102 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
3103 localityHint, isData);
3104 return success();
3105 }
3106 };
3107
3108 // The lowering of index_cast becomes an integer conversion since index becomes
3109 // an integer. If the bit width of the source and target integer types is the
3110 // same, just erase the cast. If the target type is wider, sign-extend the
3111 // value, otherwise truncate it.
3112 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
3113 using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
3114
3115 LogicalResult
matchAndRewrite__anonf5dcde621311::IndexCastOpLowering3116 matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
3117 ConversionPatternRewriter &rewriter) const override {
3118 IndexCastOpAdaptor transformed(operands);
3119
3120 auto targetType =
3121 typeConverter->convertType(indexCastOp.getResult().getType())
3122 .cast<LLVM::LLVMType>();
3123 auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
3124 unsigned targetBits = targetType.getIntegerBitWidth();
3125 unsigned sourceBits = sourceType.getIntegerBitWidth();
3126
3127 if (targetBits == sourceBits)
3128 rewriter.replaceOp(indexCastOp, transformed.in());
3129 else if (targetBits < sourceBits)
3130 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
3131 transformed.in());
3132 else
3133 rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
3134 transformed.in());
3135 return success();
3136 }
3137 };
3138
// Converts a standard comparison predicate (e.g. CmpIPredicate/CmpFPredicate)
// into the corresponding LLVM dialect predicate enum. The two enums share the
// numerical values, so a plain static_cast suffices.
3141 template <typename LLVMPredType, typename StdPredType>
convertCmpPredicate(StdPredType pred)3142 static LLVMPredType convertCmpPredicate(StdPredType pred) {
3143   return static_cast<LLVMPredType>(pred);
3144 }
3145
3146 struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
3147 using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
3148
3149 LogicalResult
matchAndRewrite__anonf5dcde621311::CmpIOpLowering3150 matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
3151 ConversionPatternRewriter &rewriter) const override {
3152 CmpIOpAdaptor transformed(operands);
3153
3154 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
3155 cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
3156 rewriter.getI64IntegerAttr(static_cast<int64_t>(
3157 convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
3158 transformed.lhs(), transformed.rhs());
3159
3160 return success();
3161 }
3162 };
3163
3164 struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
3165 using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
3166
3167 LogicalResult
matchAndRewrite__anonf5dcde621311::CmpFOpLowering3168 matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
3169 ConversionPatternRewriter &rewriter) const override {
3170 CmpFOpAdaptor transformed(operands);
3171
3172 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
3173 cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
3174 rewriter.getI64IntegerAttr(static_cast<int64_t>(
3175 convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
3176 transformed.lhs(), transformed.rhs());
3177
3178 return success();
3179 }
3180 };
3181
/// std.sitofp lowers 1:1 to llvm.sitofp; only the types are converted.
3182 struct SIToFPLowering
3183     : public OneToOneConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp> {
3184   using Super::Super;
3185 };
3186
/// std.uitofp lowers 1:1 to llvm.uitofp; only the types are converted.
3187 struct UIToFPLowering
3188     : public OneToOneConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp> {
3189   using Super::Super;
3190 };
3191
/// std.fpext lowers 1:1 to llvm.fpext; only the types are converted.
3192 struct FPExtLowering
3193     : public OneToOneConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp> {
3194   using Super::Super;
3195 };
3196
/// std.fptosi lowers 1:1 to llvm.fptosi; only the types are converted.
3197 struct FPToSILowering
3198     : public OneToOneConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp> {
3199   using Super::Super;
3200 };
3201
/// std.fptoui lowers 1:1 to llvm.fptoui; only the types are converted.
3202 struct FPToUILowering
3203     : public OneToOneConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp> {
3204   using Super::Super;
3205 };
3206
/// std.fptrunc lowers 1:1 to llvm.fptrunc; only the types are converted.
3207 struct FPTruncLowering
3208     : public OneToOneConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp> {
3209   using Super::Super;
3210 };
3211
/// std.sexti lowers 1:1 to llvm.sext; only the types are converted.
3212 struct SignExtendIOpLowering
3213     : public OneToOneConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp> {
3214   using Super::Super;
3215 };
3216
/// std.trunci lowers 1:1 to llvm.trunc; only the types are converted.
3217 struct TruncateIOpLowering
3218     : public OneToOneConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp> {
3219   using Super::Super;
3220 };
3221
/// std.zexti lowers 1:1 to llvm.zext; only the types are converted.
3222 struct ZeroExtendIOpLowering
3223     : public OneToOneConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp> {
3224   using Super::Super;
3225 };
3226
// Base class for lowering a terminator op with successors to its 1:1 LLVM
// dialect counterpart: the already-converted operands, the original successor
// blocks and all attributes are forwarded unchanged to the target op.
3228 template <typename SourceOp, typename TargetOp>
3229 struct OneToOneLLVMTerminatorLowering
3230     : public ConvertOpToLLVMPattern<SourceOp> {
3231   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3232   using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
3233
3234   LogicalResult
matchAndRewrite__anonf5dcde621311::OneToOneLLVMTerminatorLowering3235   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
3236                   ConversionPatternRewriter &rewriter) const override {
3237     rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
3238                                           op.getAttrs());
3239     return success();
3240   }
3241 };
3242
3243 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
3244 // `ReturnOp` interacts with the function signature and must have as many
3245 // operands as the function has return values. Because in LLVM IR, functions
3246 // can only return 0 or 1 value, we pack multiple values into a structure type.
3247 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
3248 // necessary before returning it
3249 struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
3250 using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
3251
3252 LogicalResult
matchAndRewrite__anonf5dcde621311::ReturnOpLowering3253 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
3254 ConversionPatternRewriter &rewriter) const override {
3255 Location loc = op.getLoc();
3256 unsigned numArguments = op.getNumOperands();
3257 SmallVector<Value, 4> updatedOperands;
3258
3259 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
3260 // For the bare-ptr calling convention, extract the aligned pointer to
3261 // be returned from the memref descriptor.
3262 for (auto it : llvm::zip(op->getOperands(), operands)) {
3263 Type oldTy = std::get<0>(it).getType();
3264 Value newOperand = std::get<1>(it);
3265 if (oldTy.isa<MemRefType>()) {
3266 MemRefDescriptor memrefDesc(newOperand);
3267 newOperand = memrefDesc.alignedPtr(rewriter, loc);
3268 } else if (oldTy.isa<UnrankedMemRefType>()) {
3269 // Unranked memref is not supported in the bare pointer calling
3270 // convention.
3271 return failure();
3272 }
3273 updatedOperands.push_back(newOperand);
3274 }
3275 } else {
3276 updatedOperands = llvm::to_vector<4>(operands);
3277 copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(),
3278 op.getOperands().getTypes(), updatedOperands,
3279 /*toDynamic=*/true);
3280 }
3281
3282 // If ReturnOp has 0 or 1 operand, create it and return immediately.
3283 if (numArguments == 0) {
3284 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
3285 op.getAttrs());
3286 return success();
3287 }
3288 if (numArguments == 1) {
3289 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
3290 op, TypeRange(), updatedOperands, op.getAttrs());
3291 return success();
3292 }
3293
3294 // Otherwise, we need to pack the arguments into an LLVM struct type before
3295 // returning.
3296 auto packedType = getTypeConverter()->packFunctionResults(
3297 llvm::to_vector<4>(op.getOperandTypes()));
3298
3299 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
3300 for (unsigned i = 0; i < numArguments; ++i) {
3301 packed = rewriter.create<LLVM::InsertValueOp>(
3302 loc, packedType, packed, updatedOperands[i],
3303 rewriter.getI64ArrayAttr(i));
3304 }
3305 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
3306 op.getAttrs());
3307 return success();
3308 }
3309 };
3310
// FIXME: this should be tablegen'ed as well.
// Lowers std.br to llvm.br via the generic terminator pattern above.
3312 struct BranchOpLowering
3313     : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
3314   using Super::Super;
3315 };
// Lowers std.cond_br to llvm.cond_br via the generic terminator pattern.
3316 struct CondBranchOpLowering
3317     : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
3318   using Super::Super;
3319 };
3320
3321 // The Splat operation is lowered to an insertelement + a shufflevector
3322 // operation. Splat to only 1-d vector result types are lowered.
3323 struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
3324 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
3325
3326 LogicalResult
matchAndRewrite__anonf5dcde621311::SplatOpLowering3327 matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
3328 ConversionPatternRewriter &rewriter) const override {
3329 VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
3330 if (!resultType || resultType.getRank() != 1)
3331 return failure();
3332
3333 // First insert it into an undef vector so we can shuffle it.
3334 auto vectorType = typeConverter->convertType(splatOp.getType());
3335 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
3336 auto zero = rewriter.create<LLVM::ConstantOp>(
3337 splatOp.getLoc(),
3338 typeConverter->convertType(rewriter.getIntegerType(32)),
3339 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
3340
3341 auto v = rewriter.create<LLVM::InsertElementOp>(
3342 splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero);
3343
3344 int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
3345 SmallVector<int32_t, 4> zeroValues(width, 0);
3346
3347 // Shuffle the value across the desired number of elements.
3348 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
3349 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
3350 zeroAttrs);
3351 return success();
3352 }
3353 };
3354
3355 // The Splat operation is lowered to an insertelement + a shufflevector
3356 // operation. Splat to only 2+-d vector result types are lowered by the
3357 // SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
3358 struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
3359 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
3360
3361 LogicalResult
matchAndRewrite__anonf5dcde621311::SplatNdOpLowering3362 matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
3363 ConversionPatternRewriter &rewriter) const override {
3364 SplatOp::Adaptor adaptor(operands);
3365 VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
3366 if (!resultType || resultType.getRank() == 1)
3367 return failure();
3368
3369 // First insert it into an undef vector so we can shuffle it.
3370 auto loc = splatOp.getLoc();
3371 auto vectorTypeInfo =
3372 extractNDVectorTypeInfo(resultType, *getTypeConverter());
3373 auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
3374 auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
3375 if (!llvmArrayTy || !llvmVectorTy)
3376 return failure();
3377
3378 // Construct returned value.
3379 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
3380
3381 // Construct a 1-D vector with the splatted value that we insert in all the
3382 // places within the returned descriptor.
3383 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
3384 auto zero = rewriter.create<LLVM::ConstantOp>(
3385 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
3386 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
3387 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
3388 adaptor.input(), zero);
3389
3390 // Shuffle the value across the desired number of elements.
3391 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
3392 SmallVector<int32_t, 4> zeroValues(width, 0);
3393 ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
3394 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
3395
3396 // Iterate of linear index, convert to coords space and insert splatted 1-D
3397 // vector in each position.
3398 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
3399 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
3400 position);
3401 });
3402 rewriter.replaceOp(splatOp, desc);
3403 return success();
3404 }
3405 };
3406
3407 /// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
extractFromI64ArrayAttr(Attribute attr)3408 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
3409 return llvm::to_vector<4>(
3410 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
3411 return a.cast<IntegerAttr>().getInt();
3412 }));
3413 }
3414
3415 /// Conversion pattern that transforms a subview op into:
3416 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
3417 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
3418 /// and stride.
3419 /// The subview op is replaced by the descriptor.
3420 struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
3421 using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
3422
3423 LogicalResult
matchAndRewrite__anonf5dcde621311::SubViewOpLowering3424 matchAndRewrite(SubViewOp subViewOp, ArrayRef<Value> operands,
3425 ConversionPatternRewriter &rewriter) const override {
3426 auto loc = subViewOp.getLoc();
3427
3428 auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
3429 auto sourceElementTy =
3430 typeConverter->convertType(sourceMemRefType.getElementType())
3431 .dyn_cast_or_null<LLVM::LLVMType>();
3432
3433 auto viewMemRefType = subViewOp.getType();
3434 auto inferredType = SubViewOp::inferResultType(
3435 subViewOp.getSourceType(),
3436 extractFromI64ArrayAttr(subViewOp.static_offsets()),
3437 extractFromI64ArrayAttr(subViewOp.static_sizes()),
3438 extractFromI64ArrayAttr(subViewOp.static_strides()))
3439 .cast<MemRefType>();
3440 auto targetElementTy =
3441 typeConverter->convertType(viewMemRefType.getElementType())
3442 .dyn_cast<LLVM::LLVMType>();
3443 auto targetDescTy = typeConverter->convertType(viewMemRefType)
3444 .dyn_cast_or_null<LLVM::LLVMType>();
3445 if (!sourceElementTy || !targetDescTy)
3446 return failure();
3447
3448 // Extract the offset and strides from the type.
3449 int64_t offset;
3450 SmallVector<int64_t, 4> strides;
3451 auto successStrides = getStridesAndOffset(inferredType, strides, offset);
3452 if (failed(successStrides))
3453 return failure();
3454
3455 // Create the descriptor.
3456 if (!operands.front().getType().isa<LLVM::LLVMType>())
3457 return failure();
3458 MemRefDescriptor sourceMemRef(operands.front());
3459 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
3460
3461 // Copy the buffer pointer from the old descriptor to the new one.
3462 Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
3463 Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
3464 loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
3465 extracted);
3466 targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
3467
3468 // Copy the buffer pointer from the old descriptor to the new one.
3469 extracted = sourceMemRef.alignedPtr(rewriter, loc);
3470 bitcastPtr = rewriter.create<LLVM::BitcastOp>(
3471 loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
3472 extracted);
3473 targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
3474
3475 auto shape = viewMemRefType.getShape();
3476 auto inferredShape = inferredType.getShape();
3477 size_t inferredShapeRank = inferredShape.size();
3478 size_t resultShapeRank = shape.size();
3479 SmallVector<bool, 4> mask =
3480 computeRankReductionMask(inferredShape, shape).getValue();
3481
3482 // Extract strides needed to compute offset.
3483 SmallVector<Value, 4> strideValues;
3484 strideValues.reserve(inferredShapeRank);
3485 for (unsigned i = 0; i < inferredShapeRank; ++i)
3486 strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
3487
3488 // Offset.
3489 auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
3490 if (!ShapedType::isDynamicStrideOrOffset(offset)) {
3491 targetMemRef.setConstantOffset(rewriter, loc, offset);
3492 } else {
3493 Value baseOffset = sourceMemRef.offset(rewriter, loc);
3494 for (unsigned i = 0; i < inferredShapeRank; ++i) {
3495 Value offset =
3496 subViewOp.isDynamicOffset(i)
3497 ? operands[subViewOp.getIndexOfDynamicOffset(i)]
3498 : rewriter.create<LLVM::ConstantOp>(
3499 loc, llvmIndexType,
3500 rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
3501 Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
3502 baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
3503 }
3504 targetMemRef.setOffset(rewriter, loc, baseOffset);
3505 }
3506
3507 // Update sizes and strides.
3508 for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
3509 i >= 0 && j >= 0; --i) {
3510 if (!mask[i])
3511 continue;
3512
3513 Value size =
3514 subViewOp.isDynamicSize(i)
3515 ? operands[subViewOp.getIndexOfDynamicSize(i)]
3516 : rewriter.create<LLVM::ConstantOp>(
3517 loc, llvmIndexType,
3518 rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
3519 targetMemRef.setSize(rewriter, loc, j, size);
3520 Value stride;
3521 if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
3522 stride = rewriter.create<LLVM::ConstantOp>(
3523 loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
3524 } else {
3525 stride =
3526 subViewOp.isDynamicStride(i)
3527 ? operands[subViewOp.getIndexOfDynamicStride(i)]
3528 : rewriter.create<LLVM::ConstantOp>(
3529 loc, llvmIndexType,
3530 rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i)));
3531 stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
3532 }
3533 targetMemRef.setStride(rewriter, loc, j, stride);
3534 j--;
3535 }
3536
3537 rewriter.replaceOp(subViewOp, {targetMemRef});
3538 return success();
3539 }
3540 };
3541
3542 /// Conversion pattern that transforms a transpose op into:
3543 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
3544 /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
3545 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
3546 /// and stride. Size and stride are permutations of the original values.
3547 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
3548 /// The transpose op is replaced by the alloca'ed pointer.
3549 class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
3550 public:
3551 using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
3552
3553 LogicalResult
matchAndRewrite(TransposeOp transposeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const3554 matchAndRewrite(TransposeOp transposeOp, ArrayRef<Value> operands,
3555 ConversionPatternRewriter &rewriter) const override {
3556 auto loc = transposeOp.getLoc();
3557 TransposeOpAdaptor adaptor(operands);
3558 MemRefDescriptor viewMemRef(adaptor.in());
3559
3560 // No permutation, early exit.
3561 if (transposeOp.permutation().isIdentity())
3562 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
3563
3564 auto targetMemRef = MemRefDescriptor::undef(
3565 rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
3566
3567 // Copy the base and aligned pointers from the old descriptor to the new
3568 // one.
3569 targetMemRef.setAllocatedPtr(rewriter, loc,
3570 viewMemRef.allocatedPtr(rewriter, loc));
3571 targetMemRef.setAlignedPtr(rewriter, loc,
3572 viewMemRef.alignedPtr(rewriter, loc));
3573
3574 // Copy the offset pointer from the old descriptor to the new one.
3575 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
3576
3577 // Iterate over the dimensions and apply size/stride permutation.
3578 for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
3579 int sourcePos = en.index();
3580 int targetPos = en.value().cast<AffineDimExpr>().getPosition();
3581 targetMemRef.setSize(rewriter, loc, targetPos,
3582 viewMemRef.size(rewriter, loc, sourcePos));
3583 targetMemRef.setStride(rewriter, loc, targetPos,
3584 viewMemRef.stride(rewriter, loc, sourcePos));
3585 }
3586
3587 rewriter.replaceOp(transposeOp, {targetMemRef});
3588 return success();
3589 }
3590 };
3591
/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
  using ConvertOpToLLVMPattern<ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx) to find which entry
    // of `dynamicSizes` corresponds to dimension `idx`.
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
      return createIndexConstant(rewriter, loc, strides[idx]);
    // Dynamic stride: accumulate the product of the sizes of the outerward
    // dimensions; the innermost dynamic stride (no nextSize yet) is 1.
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(ViewOp viewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();
    ViewOpAdaptor adaptor(operands);

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType())
            .dyn_cast<LLVM::LLVMType>();
    auto targetDescTy =
        typeConverter->convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
    if (!targetDescTy)
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    // The result type must have a static layout (strides/offset) for this
    // lowering to apply.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    // Bitcast because the element type (and thus the pointer type) may change
    // between source and result; the address space is preserved.
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload, advanced by the
    // op's byte shift (applied while still addressing the source element
    // type, before the bitcast).
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides, iterating from the innermost
    // dimension outwards so the running stride can be accumulated.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
3707
3708 struct AssumeAlignmentOpLowering
3709 : public ConvertOpToLLVMPattern<AssumeAlignmentOp> {
3710 using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
3711
3712 LogicalResult
matchAndRewrite__anonf5dcde621311::AssumeAlignmentOpLowering3713 matchAndRewrite(AssumeAlignmentOp op, ArrayRef<Value> operands,
3714 ConversionPatternRewriter &rewriter) const override {
3715 AssumeAlignmentOp::Adaptor transformed(operands);
3716 Value memref = transformed.memref();
3717 unsigned alignment = op.alignment();
3718 auto loc = op.getLoc();
3719
3720 MemRefDescriptor memRefDescriptor(memref);
3721 Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
3722
3723 // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
3724 // the asserted memref.alignedPtr isn't used anywhere else, as the real
3725 // users like load/store/views always re-extract memref.alignedPtr as they
3726 // get lowered.
3727 //
3728 // This relies on LLVM's CSE optimization (potentially after SROA), since
3729 // after CSE all memref.alignedPtr instances get de-duplicated into the same
3730 // pointer SSA value.
3731 auto intPtrType =
3732 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
3733 Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
3734 Value mask =
3735 createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
3736 Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
3737 rewriter.create<LLVM::AssumeOp>(
3738 loc, rewriter.create<LLVM::ICmpOp>(
3739 loc, LLVM::ICmpPredicate::eq,
3740 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
3741
3742 rewriter.eraseOp(op);
3743 return success();
3744 }
3745 };
3746
3747 } // namespace
3748
3749 /// Try to match the kind of a std.atomic_rmw to determine whether to use a
3750 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
matchSimpleAtomicOp(AtomicRMWOp atomicOp)3751 static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
3752 switch (atomicOp.kind()) {
3753 case AtomicRMWKind::addf:
3754 return LLVM::AtomicBinOp::fadd;
3755 case AtomicRMWKind::addi:
3756 return LLVM::AtomicBinOp::add;
3757 case AtomicRMWKind::assign:
3758 return LLVM::AtomicBinOp::xchg;
3759 case AtomicRMWKind::maxs:
3760 return LLVM::AtomicBinOp::max;
3761 case AtomicRMWKind::maxu:
3762 return LLVM::AtomicBinOp::umax;
3763 case AtomicRMWKind::mins:
3764 return LLVM::AtomicBinOp::min;
3765 case AtomicRMWKind::minu:
3766 return LLVM::AtomicBinOp::umin;
3767 default:
3768 return llvm::None;
3769 }
3770 llvm_unreachable("Invalid AtomicRMWKind");
3771 }
3772
3773 namespace {
3774
3775 struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
3776 using Base::Base;
3777
3778 LogicalResult
matchAndRewrite__anonf5dcde621811::AtomicRMWOpLowering3779 matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
3780 ConversionPatternRewriter &rewriter) const override {
3781 if (failed(match(atomicOp)))
3782 return failure();
3783 auto maybeKind = matchSimpleAtomicOp(atomicOp);
3784 if (!maybeKind)
3785 return failure();
3786 AtomicRMWOp::Adaptor adaptor(operands);
3787 auto resultType = adaptor.value().getType();
3788 auto memRefType = atomicOp.getMemRefType();
3789 auto dataPtr =
3790 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
3791 adaptor.indices(), rewriter);
3792 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
3793 atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
3794 LLVM::AtomicOrdering::acq_rel);
3795 return success();
3796 }
3797 };
3798
3799 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
3800 /// retried until it succeeds in atomically storing a new value into memory.
3801 ///
3802 /// +---------------------------------+
3803 /// | <code before the AtomicRMWOp> |
3804 /// | <compute initial %loaded> |
3805 /// | br loop(%loaded) |
3806 /// +---------------------------------+
3807 /// |
3808 /// -------| |
3809 /// | v v
3810 /// | +--------------------------------+
3811 /// | | loop(%loaded): |
3812 /// | | <body contents> |
3813 /// | | %pair = cmpxchg |
3814 /// | | %ok = %pair[0] |
3815 /// | | %new = %pair[1] |
3816 /// | | cond_br %ok, end, loop(%new) |
3817 /// | +--------------------------------+
3818 /// | | |
3819 /// |----------- |
3820 /// v
3821 /// +--------------------------------+
3822 /// | end: |
3823 /// | <code after the AtomicRMWOp> |
3824 /// +--------------------------------+
3825 ///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = atomicOp.getLoc();
    GenericAtomicRMWOp::Adaptor adaptor(operands);
    LLVM::LLVMType valueType =
        typeConverter->convertType(atomicOp.getResult().getType())
            .cast<LLVM::LLVMType>();

    // Split the block into initial, loop, and ending parts. The loop block
    // takes the currently-loaded value as its single block argument.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock =
        rewriter.createBlock(initBlock->getParent(),
                             std::next(Region::iterator(initBlock)), valueType);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`: everything that followed
    // the atomic op in `initBlock` must execute after the loop completes.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result. The
    // region's current-value argument is remapped to the loop block argument.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    // Re-emit the trailing ops of the original block in `endBlock`, with the
    // atomic op's result replaced by the freshly loaded value.
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  // `oldResult` is remapped to `newResult` in the cloned ops.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    // Erase after the full clone pass so iterators stay valid.
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};
3919
3920 } // namespace
3921
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
/// These cover the "computational" ops (arithmetic, math, control flow,
/// calls, ...); memref-related patterns are registered separately by
/// populateStdToLLVMMemoryConversionPatterns.
void mlir::populateStdToLLVMNonMemoryConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // FIXME: this should be tablegen'ed
  // NOTE(review): the list is mostly, but not strictly, alphabetical (e.g.
  // the FP*/Log* entries); left as-is since reordering insertions is not
  // obviously behavior-neutral.
  // clang-format off
  patterns.insert<
      AbsFOpLowering,
      AddCFOpLowering,
      AddFOpLowering,
      AddIOpLowering,
      AllocaOpLowering,
      AndOpLowering,
      AssertOpLowering,
      AtomicRMWOpLowering,
      BranchOpLowering,
      CallIndirectOpLowering,
      CallOpLowering,
      CeilFOpLowering,
      CmpFOpLowering,
      CmpIOpLowering,
      CondBranchOpLowering,
      CopySignOpLowering,
      CosOpLowering,
      ConstantOpLowering,
      CreateComplexOpLowering,
      DialectCastOpLowering,
      DivFOpLowering,
      ExpOpLowering,
      Exp2OpLowering,
      FloorFOpLowering,
      GenericAtomicRMWOpLowering,
      LogOpLowering,
      Log10OpLowering,
      Log2OpLowering,
      FPExtLowering,
      FPToSILowering,
      FPToUILowering,
      FPTruncLowering,
      ImOpLowering,
      IndexCastOpLowering,
      MulFOpLowering,
      MulIOpLowering,
      NegFOpLowering,
      OrOpLowering,
      PrefetchOpLowering,
      ReOpLowering,
      RemFOpLowering,
      ReturnOpLowering,
      RsqrtOpLowering,
      SIToFPLowering,
      SelectOpLowering,
      ShiftLeftOpLowering,
      SignExtendIOpLowering,
      SignedDivIOpLowering,
      SignedRemIOpLowering,
      SignedShiftRightOpLowering,
      SinOpLowering,
      SplatOpLowering,
      SplatNdOpLowering,
      SqrtOpLowering,
      SubCFOpLowering,
      SubFOpLowering,
      SubIOpLowering,
      TruncateIOpLowering,
      UIToFPLowering,
      UnsignedDivIOpLowering,
      UnsignedRemIOpLowering,
      UnsignedShiftRightOpLowering,
      XOrOpLowering,
      ZeroExtendIOpLowering>(converter);
  // clang-format on
}
3994
/// Collect the patterns that lower memref-related Standard ops (alloc,
/// dealloc, load/store, casts, views, ...) to the LLVM dialect.
void mlir::populateStdToLLVMMemoryConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.insert<
      AssumeAlignmentOpLowering,
      DeallocOpLowering,
      DimOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      RankOpLowering,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  // The std.alloc lowering depends on whether aligned_alloc-based allocation
  // was requested in the conversion options.
  if (converter.getOptions().useAlignedAlloc)
    patterns.insert<AlignedAllocOpLowering>(converter);
  else
    patterns.insert<AllocOpLowering>(converter);
}
4019
populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)4020 void mlir::populateStdToLLVMFuncOpConversionPattern(
4021 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
4022 if (converter.getOptions().useBarePtrCallConv)
4023 patterns.insert<BarePtrFuncOpConversion>(converter);
4024 else
4025 patterns.insert<FuncOpConversion>(converter);
4026 }
4027
/// Collect all Standard-to-LLVM conversion patterns: function signature
/// conversion, computational (non-memory) ops, and memref-related ops.
void mlir::populateStdToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  populateStdToLLVMFuncOpConversionPattern(converter, patterns);
  populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
  populateStdToLLVMMemoryConversionPatterns(converter, patterns);
}
4034
4035 /// Convert a non-empty list of types to be returned from a function into a
4036 /// supported LLVM IR type. In particular, if more than one value is returned,
4037 /// create an LLVM IR structure type with elements that correspond to each of
4038 /// the MLIR types converted with `convertType`.
packFunctionResults(ArrayRef<Type> types)4039 Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
4040 assert(!types.empty() && "expected non-empty list of type");
4041
4042 if (types.size() == 1)
4043 return convertCallingConventionType(types.front());
4044
4045 SmallVector<LLVM::LLVMType, 8> resultTypes;
4046 resultTypes.reserve(types.size());
4047 for (auto t : types) {
4048 auto converted =
4049 convertCallingConventionType(t).dyn_cast_or_null<LLVM::LLVMType>();
4050 if (!converted)
4051 return {};
4052 resultTypes.push_back(converted);
4053 }
4054
4055 return LLVM::LLVMType::getStructTy(&getContext(), resultTypes);
4056 }
4057
promoteOneMemRefDescriptor(Location loc,Value operand,OpBuilder & builder)4058 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
4059 OpBuilder &builder) {
4060 auto *context = builder.getContext();
4061 auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext());
4062 auto indexType = IndexType::get(context);
4063 // Alloca with proper alignment. We do not expect optimizations of this
4064 // alloca op and so we omit allocating at the entry block.
4065 auto ptrType = operand.getType().cast<LLVM::LLVMType>().getPointerTo();
4066 Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
4067 IntegerAttr::get(indexType, 1));
4068 Value allocated =
4069 builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
4070 // Store into the alloca'ed descriptor.
4071 builder.create<LLVM::StoreOp>(loc, operand, allocated);
4072 return allocated;
4073 }
4074
promoteOperands(Location loc,ValueRange opOperands,ValueRange operands,OpBuilder & builder)4075 SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
4076 ValueRange opOperands,
4077 ValueRange operands,
4078 OpBuilder &builder) {
4079 SmallVector<Value, 4> promotedOperands;
4080 promotedOperands.reserve(operands.size());
4081 for (auto it : llvm::zip(opOperands, operands)) {
4082 auto operand = std::get<0>(it);
4083 auto llvmOperand = std::get<1>(it);
4084
4085 if (options.useBarePtrCallConv) {
4086 // For the bare-ptr calling convention, we only have to extract the
4087 // aligned pointer of a memref.
4088 if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
4089 MemRefDescriptor desc(llvmOperand);
4090 llvmOperand = desc.alignedPtr(builder, loc);
4091 } else if (operand.getType().isa<UnrankedMemRefType>()) {
4092 llvm_unreachable("Unranked memrefs are not supported");
4093 }
4094 } else {
4095 if (operand.getType().isa<UnrankedMemRefType>()) {
4096 UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
4097 promotedOperands);
4098 continue;
4099 }
4100 if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
4101 MemRefDescriptor::unpack(builder, loc, llvmOperand,
4102 operand.getType().cast<MemRefType>(),
4103 promotedOperands);
4104 continue;
4105 }
4106 }
4107
4108 promotedOperands.push_back(llvmOperand);
4109 }
4110 return promotedOperands;
4111 }
4112
4113 namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
  // Default constructor: option values come from the pass's registered
  // command-line options.
  LLVMLoweringPass() = default;
  // Constructor taking explicit lowering options; each is mirrored into the
  // corresponding pass option field.
  LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
                   unsigned indexBitwidth, bool useAlignedAlloc,
                   const llvm::DataLayout &dataLayout) {
    this->useBarePtrCallConv = useBarePtrCallConv;
    this->emitCWrappers = emitCWrappers;
    this->indexBitwidth = indexBitwidth;
    this->useAlignedAlloc = useAlignedAlloc;
    // The data layout is stored as its string representation.
    this->dataLayout = dataLayout.getStringRepresentation();
  }

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    // The two options are mutually exclusive: reject the combination early.
    if (useBarePtrCallConv && emitCWrappers) {
      getOperation().emitError()
          << "incompatible conversion options: bare-pointer calling convention "
             "and C wrapper emission";
      signalPassFailure();
      return;
    }
    // Validate the data layout string before using it.
    if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
            this->dataLayout, [this](const Twine &message) {
              getOperation().emitError() << message.str();
            }))) {
      signalPassFailure();
      return;
    }

    ModuleOp m = getOperation();

    LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers,
                                  indexBitwidth, useAlignedAlloc,
                                  llvm::DataLayout(this->dataLayout)};
    LLVMTypeConverter typeConverter(&getContext(), options);

    OwningRewritePatternList patterns;
    populateStdToLLVMConversionPatterns(typeConverter, patterns);

    LLVMConversionTarget target(getContext());
    // NOTE(review): on conversion failure we signal failure but still fall
    // through and attach the data layout attribute below — presumably
    // harmless since the pass is already marked failed; confirm intentional.
    if (failed(applyPartialConversion(m, target, std::move(patterns))))
      signalPassFailure();
    m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
              StringAttr::get(this->dataLayout, m.getContext()));
  }
};
4161 } // end namespace
4162
LLVMConversionTarget(MLIRContext & ctx)4163 mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
4164 : ConversionTarget(ctx) {
4165 this->addLegalDialect<LLVM::LLVMDialect>();
4166 this->addIllegalOp<LLVM::DialectCastOp>();
4167 this->addIllegalOp<TanhOp>();
4168 }
4169
4170 std::unique_ptr<OperationPass<ModuleOp>>
createLowerToLLVMPass(const LowerToLLVMOptions & options)4171 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
4172 return std::make_unique<LLVMLoweringPass>(
4173 options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth,
4174 options.useAlignedAlloc, options.dataLayout);
4175 }
4176