//===- Bufferize.cpp - Bufferization for std ops --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements bufferization of std ops. // //===----------------------------------------------------------------------===// #include "mlir/Transforms/Bufferize.h" #include "PassDetail.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; namespace { class BufferizeDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(DimOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { DimOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.memrefOrTensor(), adaptor.index()); return success(); } }; } // namespace namespace { class BufferizeDynamicTensorFromElementsOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(DynamicTensorFromElementsOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Allocate memory. Location loc = op.getLoc(); DynamicTensorFromElementsOp::Adaptor transformed(operands); RankedTensorType tensorType = op.getType().cast(); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); Value result = rewriter.create(loc, memrefType, transformed.dynamicExtents()); // Collect loop bounds. int64_t rank = tensorType.getRank(); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); SmallVector lowerBounds(rank, zero); SmallVector steps(rank, one); SmallVector upperBounds; int nextDynamicIndex = 0; for (int i = 0; i < rank; i++) { Value upperBound = tensorType.isDynamicDim(i) ? transformed.dynamicExtents()[nextDynamicIndex++] : rewriter.create(loc, memrefType.getDimSize(i)); upperBounds.push_back(upperBound); } // Generate tensor elements with a parallel loop. rewriter.create( loc, lowerBounds, upperBounds, steps, [&](OpBuilder &b, Location loc, ValueRange ivs) { BlockAndValueMapping mapping; mapping.map(op.body().getArguments(), ivs); for (auto &nestedOp : op.getBody()->without_terminator()) b.clone(nestedOp, mapping); auto yieldOp = cast(op.getBody()->getTerminator()); b.create(loc, mapping.lookup(yieldOp.value()), result, ivs); b.create(loc); }); rewriter.replaceOp(op, {result}); return success(); } }; } // namespace namespace { class BufferizeExtractElementOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ExtractElementOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ExtractElementOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.aggregate(), adaptor.indices()); return success(); } }; } // namespace namespace { class BufferizeSelectOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!op.condition().getType().isa()) return rewriter.notifyMatchFailure(op, "requires scalar condition"); SelectOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); return success(); } }; } // namespace namespace { class BufferizeTensorCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorCastOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, operands[0]); return success(); } }; } // namespace namespace { class BufferizeTensorFromElementsOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorFromElementsOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { int numberOfElements = op.elements().size(); auto resultType = MemRefType::get( {numberOfElements}, op.getType().cast().getElementType()); Value result = rewriter.create(op.getLoc(), resultType); for (auto element : llvm::enumerate(op.elements())) { Value index = rewriter.create(op.getLoc(), element.index()); rewriter.create(op.getLoc(), element.value(), result, index); } rewriter.replaceOp(op, {result}); return success(); } }; } // namespace void mlir::populateStdBufferizePatterns(MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert< // clang-format off BufferizeDimOp, BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp, BufferizeSelectOp, BufferizeTensorCastOp, BufferizeTensorFromElementsOp // clang-format on >(typeConverter, context); } namespace { struct StdBufferizePass : public StdBufferizeBase { void runOnFunction() override { auto *context = &getContext(); BufferizeTypeConverter typeConverter; OwningRewritePatternList patterns; ConversionTarget target(*context); target.addLegalDialect(); target.addLegalDialect(); populateStdBufferizePatterns(context, typeConverter, patterns); target.addIllegalOp(); // We only bufferize the case of tensor selected type and scalar condition, // as that boils down to a select over memref descriptors (don't need to // touch the data). target.addDynamicallyLegalOp([&](SelectOp op) { return typeConverter.isLegal(op.getType()) || !op.condition().getType().isa(); }); target.addDynamicallyLegalOp( [&](DimOp op) { return typeConverter.isLegal(op); }); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr mlir::createStdBufferizePass() { return std::make_unique(); }