• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Predicate.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
20 
21 using namespace mlir;
22 using namespace tblgen;
23 
24 // Construct a Predicate from a record.
Pred(const llvm::Record * record)25 Pred::Pred(const llvm::Record *record) : def(record) {
26   assert(def->isSubClassOf("Pred") &&
27          "must be a subclass of TableGen 'Pred' class");
28 }
29 
30 // Construct a Predicate from an initializer.
Pred(const llvm::Init * init)31 Pred::Pred(const llvm::Init *init) : def(nullptr) {
32   if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
33     def = defInit->getDef();
34 }
35 
getCondition() const36 std::string Pred::getCondition() const {
37   // Static dispatch to subclasses.
38   if (def->isSubClassOf("CombinedPred"))
39     return static_cast<const CombinedPred *>(this)->getConditionImpl();
40   if (def->isSubClassOf("CPred"))
41     return static_cast<const CPred *>(this)->getConditionImpl();
42   llvm_unreachable("Pred::getCondition must be overridden in subclasses");
43 }
44 
isCombined() const45 bool Pred::isCombined() const {
46   return def && def->isSubClassOf("CombinedPred");
47 }
48 
getLoc() const49 ArrayRef<llvm::SMLoc> Pred::getLoc() const { return def->getLoc(); }
50 
CPred(const llvm::Record * record)51 CPred::CPred(const llvm::Record *record) : Pred(record) {
52   assert(def->isSubClassOf("CPred") &&
53          "must be a subclass of Tablegen 'CPred' class");
54 }
55 
CPred(const llvm::Init * init)56 CPred::CPred(const llvm::Init *init) : Pred(init) {
57   assert((!def || def->isSubClassOf("CPred")) &&
58          "must be a subclass of Tablegen 'CPred' class");
59 }
60 
61 // Get condition of the C Predicate.
getConditionImpl() const62 std::string CPred::getConditionImpl() const {
63   assert(!isNull() && "null predicate does not have a condition");
64   return std::string(def->getValueAsString("predExpr"));
65 }
66 
CombinedPred(const llvm::Record * record)67 CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
68   assert(def->isSubClassOf("CombinedPred") &&
69          "must be a subclass of Tablegen 'CombinedPred' class");
70 }
71 
CombinedPred(const llvm::Init * init)72 CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
73   assert((!def || def->isSubClassOf("CombinedPred")) &&
74          "must be a subclass of Tablegen 'CombinedPred' class");
75 }
76 
getCombinerDef() const77 const llvm::Record *CombinedPred::getCombinerDef() const {
78   assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
79   return def->getValueAsDef("kind");
80 }
81 
getChildren() const82 const std::vector<llvm::Record *> CombinedPred::getChildren() const {
83   assert(def->getValue("children") &&
84          "CombinedPred must have a value 'children'");
85   return def->getValueAsListOfDefs("children");
86 }
87 
88 namespace {
89 // Kinds of nodes in a logical predicate tree.
90 enum class PredCombinerKind {
91   Leaf,
92   And,
93   Or,
94   Not,
95   SubstLeaves,
96   Concat,
97   // Special kinds that are used in simplification.
98   False,
99   True
100 };
101 
102 // A node in a logical predicate tree.
103 struct PredNode {
104   PredCombinerKind kind;
105   const Pred *predicate;
106   SmallVector<PredNode *, 4> children;
107   std::string expr;
108 
109   // Prefix and suffix are used by ConcatPred.
110   std::string prefix;
111   std::string suffix;
112 };
113 } // end anonymous namespace
114 
115 // Get a predicate tree node kind based on the kind used in the predicate
116 // TableGen record.
getPredCombinerKind(const Pred & pred)117 static PredCombinerKind getPredCombinerKind(const Pred &pred) {
118   if (!pred.isCombined())
119     return PredCombinerKind::Leaf;
120 
121   const auto &combinedPred = static_cast<const CombinedPred &>(pred);
122   return StringSwitch<PredCombinerKind>(
123              combinedPred.getCombinerDef()->getName())
124       .Case("PredCombinerAnd", PredCombinerKind::And)
125       .Case("PredCombinerOr", PredCombinerKind::Or)
126       .Case("PredCombinerNot", PredCombinerKind::Not)
127       .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
128       .Case("PredCombinerConcat", PredCombinerKind::Concat);
129 }
130 
131 namespace {
132 // Substitution<pattern, replacement>.
133 using Subst = std::pair<StringRef, StringRef>;
134 } // end anonymous namespace
135 
136 // Build the predicate tree starting from the top-level predicate, which may
137 // have children, and perform leaf substitutions inplace.  Note that after
138 // substitution, nodes are still pointing to the original TableGen record.
139 // All nodes are created within "allocator".
140 static PredNode *
buildPredicateTree(const Pred & root,llvm::SpecificBumpPtrAllocator<PredNode> & allocator,ArrayRef<Subst> substitutions)141 buildPredicateTree(const Pred &root,
142                    llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
143                    ArrayRef<Subst> substitutions) {
144   auto *rootNode = allocator.Allocate();
145   new (rootNode) PredNode;
146   rootNode->kind = getPredCombinerKind(root);
147   rootNode->predicate = &root;
148   if (!root.isCombined()) {
149     rootNode->expr = root.getCondition();
150     // Apply all parent substitutions from innermost to outermost.
151     for (const auto &subst : llvm::reverse(substitutions)) {
152       auto pos = rootNode->expr.find(std::string(subst.first));
153       while (pos != std::string::npos) {
154         rootNode->expr.replace(pos, subst.first.size(),
155                                std::string(subst.second));
156         // Skip the newly inserted substring, which itself may consider the
157         // pattern to match.
158         pos += subst.second.size();
159         // Find the next possible match position.
160         pos = rootNode->expr.find(std::string(subst.first), pos);
161       }
162     }
163     return rootNode;
164   }
165 
166   // If the current combined predicate is a leaf substitution, append it to the
167   // list before continuing.
168   auto allSubstitutions = llvm::to_vector<4>(substitutions);
169   if (rootNode->kind == PredCombinerKind::SubstLeaves) {
170     const auto &substPred = static_cast<const SubstLeavesPred &>(root);
171     allSubstitutions.push_back(
172         {substPred.getPattern(), substPred.getReplacement()});
173   }
174   // If the current predicate is a ConcatPred, record the prefix and suffix.
175   else if (rootNode->kind == PredCombinerKind::Concat) {
176     const auto &concatPred = static_cast<const ConcatPred &>(root);
177     rootNode->prefix = std::string(concatPred.getPrefix());
178     rootNode->suffix = std::string(concatPred.getSuffix());
179   }
180 
181   // Build child subtrees.
182   auto combined = static_cast<const CombinedPred &>(root);
183   for (const auto *record : combined.getChildren()) {
184     auto childTree =
185         buildPredicateTree(Pred(record), allocator, allSubstitutions);
186     rootNode->children.push_back(childTree);
187   }
188   return rootNode;
189 }
190 
191 // Simplify a predicate tree rooted at "node" using the predicates that are
192 // known to be true(false).  For AND(OR) combined predicates, if any of the
193 // children is known to be false(true), the result is also false(true).
194 // Furthermore, for AND(OR) combined predicates, children that are known to be
195 // true(false) don't have to be checked dynamically.
196 static PredNode *
propagateGroundTruth(PredNode * node,const llvm::SmallPtrSetImpl<Pred * > & knownTruePreds,const llvm::SmallPtrSetImpl<Pred * > & knownFalsePreds)197 propagateGroundTruth(PredNode *node,
198                      const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
199                      const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
200   // If the current predicate is known to be true or false, change the kind of
201   // the node and return immediately.
202   if (knownTruePreds.count(node->predicate) != 0) {
203     node->kind = PredCombinerKind::True;
204     node->children.clear();
205     return node;
206   }
207   if (knownFalsePreds.count(node->predicate) != 0) {
208     node->kind = PredCombinerKind::False;
209     node->children.clear();
210     return node;
211   }
212 
213   // If the current node is a substitution, stop recursion now.
214   // The expressions in the leaves below this node were rewritten, but the nodes
215   // still point to the original predicate records.  While the original
216   // predicate may be known to be true or false, it is not necessarily the case
217   // after rewriting.
218   // TODO: we can support ground truth for rewritten
219   // predicates by either (a) having our own unique'ing of the predicates
220   // instead of relying on TableGen record pointers or (b) taking ground truth
221   // values optionally prefixed with a list of substitutions to apply, e.g.
222   // "predX is true by itself as well as predSubY leaf substitution had been
223   // applied to it".
224   if (node->kind == PredCombinerKind::SubstLeaves) {
225     return node;
226   }
227 
228   // Otherwise, look at child nodes.
229 
230   // Move child nodes into some local variable so that they can be optimized
231   // separately and re-added if necessary.
232   llvm::SmallVector<PredNode *, 4> children;
233   std::swap(node->children, children);
234 
235   for (auto &child : children) {
236     // First, simplify the child.  This maintains the predicate as it was.
237     auto simplifiedChild =
238         propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
239 
240     // Just add the child if we don't know how to simplify the current node.
241     if (node->kind != PredCombinerKind::And &&
242         node->kind != PredCombinerKind::Or) {
243       node->children.push_back(simplifiedChild);
244       continue;
245     }
246 
247     // Second, based on the type define which known values of child predicates
248     // immediately collapse this predicate to a known value, and which others
249     // may be safely ignored.
250     //   OR(..., True, ...) = True
251     //   OR(..., False, ...) = OR(..., ...)
252     //   AND(..., False, ...) = False
253     //   AND(..., True, ...) = AND(..., ...)
254     auto collapseKind = node->kind == PredCombinerKind::And
255                             ? PredCombinerKind::False
256                             : PredCombinerKind::True;
257     auto eraseKind = node->kind == PredCombinerKind::And
258                          ? PredCombinerKind::True
259                          : PredCombinerKind::False;
260     const auto &collapseList =
261         node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
262     const auto &eraseList =
263         node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
264     if (simplifiedChild->kind == collapseKind ||
265         collapseList.count(simplifiedChild->predicate) != 0) {
266       node->kind = collapseKind;
267       node->children.clear();
268       return node;
269     } else if (simplifiedChild->kind == eraseKind ||
270                eraseList.count(simplifiedChild->predicate) != 0) {
271       continue;
272     }
273     node->children.push_back(simplifiedChild);
274   }
275   return node;
276 }
277 
278 // Combine a list of predicate expressions using a binary combiner.  If a list
279 // is empty, return "init".
combineBinary(ArrayRef<std::string> children,std::string combiner,std::string init)280 static std::string combineBinary(ArrayRef<std::string> children,
281                                  std::string combiner, std::string init) {
282   if (children.empty())
283     return init;
284 
285   auto size = children.size();
286   if (size == 1)
287     return children.front();
288 
289   std::string str;
290   llvm::raw_string_ostream os(str);
291   os << '(' << children.front() << ')';
292   for (unsigned i = 1; i < size; ++i) {
293     os << ' ' << combiner << " (" << children[i] << ')';
294   }
295   return os.str();
296 }
297 
298 // Prepend negation to the only condition in the predicate expression list.
combineNot(ArrayRef<std::string> children)299 static std::string combineNot(ArrayRef<std::string> children) {
300   assert(children.size() == 1 && "expected exactly one child predicate of Neg");
301   return (Twine("!(") + children.front() + Twine(')')).str();
302 }
303 
304 // Recursively traverse the predicate tree in depth-first post-order and build
305 // the final expression.
getCombinedCondition(const PredNode & root)306 static std::string getCombinedCondition(const PredNode &root) {
307   // Immediately return for non-combiner predicates that don't have children.
308   if (root.kind == PredCombinerKind::Leaf)
309     return root.expr;
310   if (root.kind == PredCombinerKind::True)
311     return "true";
312   if (root.kind == PredCombinerKind::False)
313     return "false";
314 
315   // Recurse into children.
316   llvm::SmallVector<std::string, 4> childExpressions;
317   childExpressions.reserve(root.children.size());
318   for (const auto &child : root.children)
319     childExpressions.push_back(getCombinedCondition(*child));
320 
321   // Combine the expressions based on the predicate node kind.
322   if (root.kind == PredCombinerKind::And)
323     return combineBinary(childExpressions, "&&", "true");
324   if (root.kind == PredCombinerKind::Or)
325     return combineBinary(childExpressions, "||", "false");
326   if (root.kind == PredCombinerKind::Not)
327     return combineNot(childExpressions);
328   if (root.kind == PredCombinerKind::Concat) {
329     assert(childExpressions.size() == 1 &&
330            "ConcatPred should only have one child");
331     return root.prefix + childExpressions.front() + root.suffix;
332   }
333 
334   // Substitutions were applied before so just ignore them.
335   if (root.kind == PredCombinerKind::SubstLeaves) {
336     assert(childExpressions.size() == 1 &&
337            "substitution predicate must have one child");
338     return childExpressions[0];
339   }
340 
341   llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
342 }
343 
getConditionImpl() const344 std::string CombinedPred::getConditionImpl() const {
345   llvm::SpecificBumpPtrAllocator<PredNode> allocator;
346   auto predicateTree = buildPredicateTree(*this, allocator, {});
347   predicateTree =
348       propagateGroundTruth(predicateTree,
349                            /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
350                            /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
351 
352   return getCombinedCondition(*predicateTree);
353 }
354 
getPattern() const355 StringRef SubstLeavesPred::getPattern() const {
356   return def->getValueAsString("pattern");
357 }
358 
getReplacement() const359 StringRef SubstLeavesPred::getReplacement() const {
360   return def->getValueAsString("replacement");
361 }
362 
getPrefix() const363 StringRef ConcatPred::getPrefix() const {
364   return def->getValueAsString("prefix");
365 }
366 
getSuffix() const367 StringRef ConcatPred::getSuffix() const {
368   return def->getValueAsString("suffix");
369 }
370