1 //===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains definitions for nodes of a tree structure for representing 10 // the general control flow within a pattern match. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 15 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 16 17 #include "Predicate.h" 18 #include "mlir/Dialect/PDL/IR/PDL.h" 19 #include "llvm/ADT/MapVector.h" 20 21 namespace mlir { 22 class ModuleOp; 23 24 namespace pdl_to_pdl_interp { 25 26 class MatcherNode; 27 28 /// A PositionalPredicate is a predicate that is associated with a specific 29 /// positional value. 30 struct PositionalPredicate { PositionalPredicatePositionalPredicate31 PositionalPredicate(Position *pos, 32 const PredicateBuilder::Predicate &predicate) 33 : position(pos), question(predicate.first), answer(predicate.second) {} 34 35 /// The position the predicate is applied to. 36 Position *position; 37 38 /// The question that the predicate applies. 39 Qualifier *question; 40 41 /// The expected answer of the predicate. 42 Qualifier *answer; 43 }; 44 45 //===----------------------------------------------------------------------===// 46 // MatcherNode 47 //===----------------------------------------------------------------------===// 48 49 /// This class represents the base of a predicate matcher node. 50 class MatcherNode { 51 public: 52 virtual ~MatcherNode() = default; 53 54 /// Given a module containing PDL pattern operations, generate a matcher tree 55 /// using the patterns within the given module and return the root matcher 56 /// node. `valueToPosition` is a map that is populated with the original 57 /// pdl values and their corresponding positions in the matcher tree. 58 static std::unique_ptr<MatcherNode> 59 generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 60 DenseMap<Value, Position *> &valueToPosition); 61 62 /// Returns the position on which the question predicate should be checked. getPosition()63 Position *getPosition() const { return position; } 64 65 /// Returns the predicate checked on this node. getQuestion()66 Qualifier *getQuestion() const { return question; } 67 68 /// Returns the node that should be visited if this, or a subsequent node 69 /// fails. getFailureNode()70 std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; } 71 72 /// Sets the node that should be visited if this, or a subsequent node fails. setFailureNode(std::unique_ptr<MatcherNode> node)73 void setFailureNode(std::unique_ptr<MatcherNode> node) { 74 failureNode = std::move(node); 75 } 76 77 /// Returns the unique type ID of this matcher instance. This should not be 78 /// used directly, and is provided to support type casting. getMatcherTypeID()79 TypeID getMatcherTypeID() const { return matcherTypeID; } 80 81 protected: 82 MatcherNode(TypeID matcherTypeID, Position *position = nullptr, 83 Qualifier *question = nullptr, 84 std::unique_ptr<MatcherNode> failureNode = nullptr); 85 86 private: 87 /// The position on which the predicate should be checked. 88 Position *position; 89 90 /// The predicate that is checked on the given position. 91 Qualifier *question; 92 93 /// The node to visit if this node fails. 94 std::unique_ptr<MatcherNode> failureNode; 95 96 /// An owning store for the failure node if it is owned by this node. 97 std::unique_ptr<MatcherNode> failureNodeStorage; 98 99 /// A unique identifier for the derived matcher node, used for type casting. 100 TypeID matcherTypeID; 101 }; 102 103 //===----------------------------------------------------------------------===// 104 // BoolNode 105 106 /// A BoolNode denotes a question with a boolean-like result. These nodes branch 107 /// to a single node on a successful result, otherwise defaulting to the failure 108 /// node. 109 struct BoolNode : public MatcherNode { 110 BoolNode(Position *position, Qualifier *question, Qualifier *answer, 111 std::unique_ptr<MatcherNode> successNode, 112 std::unique_ptr<MatcherNode> failureNode = nullptr); 113 114 /// Returns if the given matcher node is an instance of this class, used to 115 /// support type casting. classofBoolNode116 static bool classof(const MatcherNode *node) { 117 return node->getMatcherTypeID() == TypeID::get<BoolNode>(); 118 } 119 120 /// Returns the expected answer of this boolean node. getAnswerBoolNode121 Qualifier *getAnswer() const { return answer; } 122 123 /// Returns the node that should be visited on success. getSuccessNodeBoolNode124 std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; } 125 126 private: 127 /// The expected answer of this boolean node. 128 Qualifier *answer; 129 130 /// The next node if this node succeeds. Otherwise, go to the failure node. 131 std::unique_ptr<MatcherNode> successNode; 132 }; 133 134 //===----------------------------------------------------------------------===// 135 // ExitNode 136 137 /// An ExitNode is a special sentinel node that denotes the end of matcher. 138 struct ExitNode : public MatcherNode { ExitNodeExitNode139 ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {} 140 141 /// Returns if the given matcher node is an instance of this class, used to 142 /// support type casting. classofExitNode143 static bool classof(const MatcherNode *node) { 144 return node->getMatcherTypeID() == TypeID::get<ExitNode>(); 145 } 146 }; 147 148 //===----------------------------------------------------------------------===// 149 // SuccessNode 150 151 /// A SuccessNode denotes that a given high level pattern has successfully been 152 /// matched. This does not terminate the matcher, as there may be multiple 153 /// successful matches. 154 struct SuccessNode : public MatcherNode { 155 explicit SuccessNode(pdl::PatternOp pattern, 156 std::unique_ptr<MatcherNode> failureNode); 157 158 /// Returns if the given matcher node is an instance of this class, used to 159 /// support type casting. classofSuccessNode160 static bool classof(const MatcherNode *node) { 161 return node->getMatcherTypeID() == TypeID::get<SuccessNode>(); 162 } 163 164 /// Return the high level pattern operation that is matched with this node. getPatternSuccessNode165 pdl::PatternOp getPattern() const { return pattern; } 166 167 private: 168 /// The high level pattern operation that was successfully matched with this 169 /// node. 170 pdl::PatternOp pattern; 171 }; 172 173 //===----------------------------------------------------------------------===// 174 // SwitchNode 175 176 /// A SwitchNode denotes a question with multiple potential results. These nodes 177 /// branch to a specific node based on the result of the question. 178 struct SwitchNode : public MatcherNode { 179 SwitchNode(Position *position, Qualifier *question); 180 181 /// Returns if the given matcher node is an instance of this class, used to 182 /// support type casting. classofSwitchNode183 static bool classof(const MatcherNode *node) { 184 return node->getMatcherTypeID() == TypeID::get<SwitchNode>(); 185 } 186 187 /// Returns the children of this switch node. The children are contained 188 /// within a mapping between the various case answers to destination matcher 189 /// nodes. 190 using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>; getChildrenSwitchNode191 ChildMapT &getChildren() { return children; } 192 193 private: 194 /// Switch predicate "answers" select the child. Answers that are not found 195 /// default to the failure node. 196 ChildMapT children; 197 }; 198 199 } // end namespace pdl_to_pdl_interp 200 } // end namespace mlir 201 202 #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 203