1 //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a set of simple combiners for optimizing operations in
10 // the Toy dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "toy/Dialect.h"
17 #include <numeric>
18 using namespace mlir;
19 using namespace toy;
20
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
/// Wrapping the include in an anonymous namespace gives every class it
/// defines internal linkage, so they stay private to this translation unit.
/// NOTE(review): ToyCombine.inc is presumably generated by mlir-tblgen from the
/// DRR definitions and is expected to provide ReshapeReshapeOptPattern,
/// RedundantReshapeOptPattern and FoldConstantReshapeOptPattern, which
/// ReshapeOp::getCanonicalizationPatterns registers below — confirm against
/// the build rules.
#include "ToyCombine.inc"
} // end anonymous namespace
25
26 /// Fold simple cast operations that return the same type as the input.
fold(ArrayRef<Attribute> operands)27 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
28 return mlir::impl::foldCastOp(*this);
29 }
30
31 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
32 /// optimizes the following scenario: transpose(transpose(x)) -> x
33 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
34 /// We register this pattern to match every toy.transpose in the IR.
35 /// The "benefit" is used by the framework to order the patterns and process
36 /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose37 SimplifyRedundantTranspose(mlir::MLIRContext *context)
38 : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
39
40 /// This method attempts to match a pattern and rewrite it. The rewriter
41 /// argument is the orchestrator of the sequence of rewrites. The pattern is
42 /// expected to interact with it to perform any changes to the IR from here.
43 mlir::LogicalResult
matchAndRewriteSimplifyRedundantTranspose44 matchAndRewrite(TransposeOp op,
45 mlir::PatternRewriter &rewriter) const override {
46 // Look through the input of the current transpose.
47 mlir::Value transposeInput = op.getOperand();
48 TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
49
50 // Input defined by another transpose? If not, no match.
51 if (!transposeInputOp)
52 return failure();
53
54 // Otherwise, we have a redundant transpose. Use the rewriter.
55 rewriter.replaceOp(op, {transposeInputOp.getOperand()});
56 return success();
57 }
58 };
59
60 /// Register our patterns as "canonicalization" patterns on the TransposeOp so
61 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)62 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
63 MLIRContext *context) {
64 results.insert<SimplifyRedundantTranspose>(context);
65 }
66
67 /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
68 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)69 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
70 MLIRContext *context) {
71 results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
72 FoldConstantReshapeOptPattern>(context);
73 }
74