1 //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares a byte-code and interpreter for pattern rewrites in MLIR. 10 // The byte-code is constructed from the PDL Interpreter dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_REWRITE_BYTECODE_H_ 15 #define MLIR_REWRITE_BYTECODE_H_ 16 17 #include "mlir/IR/PatternMatch.h" 18 19 namespace mlir { 20 namespace pdl_interp { 21 class RecordMatchOp; 22 } // end namespace pdl_interp 23 24 namespace detail { 25 class PDLByteCode; 26 27 /// Use generic bytecode types. ByteCodeField refers to the actual bytecode 28 /// entries (set to uint8_t for "byte" bytecode). ByteCodeAddr refers to size of 29 /// indices into the bytecode. Correctness is checked with static asserts. 30 using ByteCodeField = uint16_t; 31 using ByteCodeAddr = uint32_t; 32 33 //===----------------------------------------------------------------------===// 34 // PDLByteCodePattern 35 //===----------------------------------------------------------------------===// 36 37 /// All of the data pertaining to a specific pattern within the bytecode. 38 class PDLByteCodePattern : public Pattern { 39 public: 40 static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, 41 ByteCodeAddr rewriterAddr); 42 43 /// Return the bytecode address of the rewriter for this pattern. getRewriterAddr()44 ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } 45 46 private: 47 template <typename... Args> PDLByteCodePattern(ByteCodeAddr rewriterAddr,Args &&...patternArgs)48 PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) 49 : Pattern(std::forward<Args>(patternArgs)...), 50 rewriterAddr(rewriterAddr) {} 51 52 /// The address of the rewriter for this pattern. 53 ByteCodeAddr rewriterAddr; 54 }; 55 56 //===----------------------------------------------------------------------===// 57 // PDLByteCodeMutableState 58 //===----------------------------------------------------------------------===// 59 60 /// This class contains the mutable state of a bytecode instance. This allows 61 /// for a bytecode instance to be cached and reused across various different 62 /// threads/drivers. 63 class PDLByteCodeMutableState { 64 public: 65 /// Initialize the state from a bytecode instance. 66 void initialize(PDLByteCode &bytecode); 67 68 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 69 /// to the position of the pattern within the range returned by 70 /// `PDLByteCode::getPatterns`. 71 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); 72 73 private: 74 /// Allow access to data fields. 75 friend class PDLByteCode; 76 77 /// The mutable block of memory used during the matching and rewriting phases 78 /// of the bytecode. 79 std::vector<const void *> memory; 80 81 /// The up-to-date benefits of the patterns held by the bytecode. The order 82 /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. 83 std::vector<PatternBenefit> currentPatternBenefits; 84 }; 85 86 //===----------------------------------------------------------------------===// 87 // PDLByteCode 88 //===----------------------------------------------------------------------===// 89 90 /// The bytecode class is also the interpreter. Contains the bytecode itself, 91 /// the static info, addresses of the rewriter functions, the interpreter 92 /// memory buffer, and the execution context. 93 class PDLByteCode { 94 public: 95 /// Each successful match returns a MatchResult, which contains information 96 /// necessary to execute the rewriter and indicates the originating pattern. 97 struct MatchResult { MatchResultMatchResult98 MatchResult(Location loc, const PDLByteCodePattern &pattern, 99 PatternBenefit benefit) 100 : location(loc), pattern(&pattern), benefit(benefit) {} 101 102 /// The location of operations to be replaced. 103 Location location; 104 /// Memory values defined in the matcher that are passed to the rewriter. 105 SmallVector<const void *, 4> values; 106 /// The originating pattern that was matched. This is always non-null, but 107 /// represented with a pointer to allow for assignment. 108 const PDLByteCodePattern *pattern; 109 /// The current benefit of the pattern that was matched. 110 PatternBenefit benefit; 111 }; 112 113 /// Create a ByteCode instance from the given module containing operations in 114 /// the PDL interpreter dialect. 115 PDLByteCode(ModuleOp module, 116 llvm::StringMap<PDLConstraintFunction> constraintFns, 117 llvm::StringMap<PDLCreateFunction> createFns, 118 llvm::StringMap<PDLRewriteFunction> rewriteFns); 119 120 /// Return the patterns held by the bytecode. getPatterns()121 ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; } 122 123 /// Initialize the given state such that it can be used to execute the current 124 /// bytecode. 125 void initializeMutableState(PDLByteCodeMutableState &state) const; 126 127 /// Run the pattern matcher on the given root operation, collecting the 128 /// matched patterns in `matches`. 129 void match(Operation *op, PatternRewriter &rewriter, 130 SmallVectorImpl<MatchResult> &matches, 131 PDLByteCodeMutableState &state) const; 132 133 /// Run the rewriter of the given pattern that was previously matched in 134 /// `match`. 135 void rewrite(PatternRewriter &rewriter, const MatchResult &match, 136 PDLByteCodeMutableState &state) const; 137 138 private: 139 /// Execute the given byte code starting at the provided instruction `inst`. 140 /// `matches` is an optional field provided when this function is executed in 141 /// a matching context. 142 void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, 143 PDLByteCodeMutableState &state, 144 SmallVectorImpl<MatchResult> *matches) const; 145 146 /// A vector containing pointers to unqiued data. The storage is intentionally 147 /// opaque such that we can store a wide range of data types. The types of 148 /// data stored here include: 149 /// * Attribute, Identifier, OperationName, Type 150 std::vector<const void *> uniquedData; 151 152 /// A vector containing the generated bytecode for the matcher. 153 SmallVector<ByteCodeField, 64> matcherByteCode; 154 155 /// A vector containing the generated bytecode for all of the rewriters. 156 SmallVector<ByteCodeField, 64> rewriterByteCode; 157 158 /// The set of patterns contained within the bytecode. 159 SmallVector<PDLByteCodePattern, 32> patterns; 160 161 /// A set of user defined functions invoked via PDL. 162 std::vector<PDLConstraintFunction> constraintFunctions; 163 std::vector<PDLCreateFunction> createFunctions; 164 std::vector<PDLRewriteFunction> rewriteFunctions; 165 166 /// The maximum memory index used by a value. 167 ByteCodeField maxValueMemoryIndex = 0; 168 }; 169 170 } // end namespace detail 171 } // end namespace mlir 172 173 #endif // MLIR_REWRITE_BYTECODE_H_ 174