//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for "predicates" used when converting PDL into
// a matcher tree. Predicates are composed of three different parts:
//
//  * Positions
//    - A position refers to a specific location on the input DAG, i.e. an
//      existing MLIR entity being matched. These can be attributes, operands,
//      operations, results, and types. Each position also defines a relation to
//      its parent. For example, the operand `[0] -> 1` has a parent operation
//      position `[0]`. The attribute `[0, 1] -> "myAttr"` has a parent
//      operation position `[0, 1]`. The operation `[0, 1]` has a parent
//      operand edge `[0] -> 1` (i.e. it is the defining op of operand 1). The
//      only position without a parent is `[0]`, which refers to the root
//      operation.
//  * Questions
//    - A question refers to a query on a specific positional value. For
//      example, an operation name question checks the name of an operation
//      position.
//  * Answers
//    - An answer is the expected result of a question. For example, when
//      matching an operation with the name "foo.op", the question would be an
//      operation name question, with an expected answer of "foo.op".
29 // 30 //===----------------------------------------------------------------------===// 31 32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 34 35 #include "mlir/IR/MLIRContext.h" 36 #include "mlir/IR/OperationSupport.h" 37 #include "mlir/IR/PatternMatch.h" 38 #include "mlir/IR/Types.h" 39 40 namespace mlir { 41 namespace pdl_to_pdl_interp { 42 namespace Predicates { 43 /// An enumeration of the kinds of predicates. 44 enum Kind : unsigned { 45 /// Positions, ordered by decreasing priority. 46 OperationPos, 47 OperandPos, 48 AttributePos, 49 ResultPos, 50 TypePos, 51 52 // Questions, ordered by dependency and decreasing priority. 53 IsNotNullQuestion, 54 OperationNameQuestion, 55 TypeQuestion, 56 AttributeQuestion, 57 OperandCountQuestion, 58 ResultCountQuestion, 59 EqualToQuestion, 60 ConstraintQuestion, 61 62 // Answers. 63 AttributeAnswer, 64 TrueAnswer, 65 OperationNameAnswer, 66 TypeAnswer, 67 UnsignedAnswer, 68 }; 69 } // end namespace Predicates 70 71 /// Base class for all predicates, used to allow efficient pointer comparison. 72 template <typename ConcreteT, typename BaseT, typename Key, 73 Predicates::Kind Kind> 74 class PredicateBase : public BaseT { 75 public: 76 using KeyTy = Key; 77 using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>; 78 79 template <typename KeyT> PredicateBase(KeyT && key)80 explicit PredicateBase(KeyT &&key) 81 : BaseT(Kind), key(std::forward<KeyT>(key)) {} 82 83 /// Get an instance of this position. 84 template <typename... Args> get(StorageUniquer & uniquer,Args &&...args)85 static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { 86 return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...); 87 } 88 89 /// Construct an instance with the given storage allocator. 
90 template <typename KeyT> construct(StorageUniquer::StorageAllocator & alloc,KeyT && key)91 static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, 92 KeyT &&key) { 93 return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key)); 94 } 95 96 /// Utility methods required by the storage allocator. 97 bool operator==(const KeyTy &key) const { return this->key == key; } classof(const BaseT * pred)98 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } 99 100 /// Return the key value of this predicate. getValue()101 const KeyTy &getValue() const { return key; } 102 103 protected: 104 KeyTy key; 105 }; 106 107 /// Base storage for simple predicates that only unique with the kind. 108 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind> 109 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT { 110 public: 111 using Base = PredicateBase<ConcreteT, BaseT, void, Kind>; 112 PredicateBase()113 explicit PredicateBase() : BaseT(Kind) {} 114 get(StorageUniquer & uniquer)115 static ConcreteT *get(StorageUniquer &uniquer) { 116 return uniquer.get<ConcreteT>(); 117 } classof(const BaseT * pred)118 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } 119 }; 120 121 //===----------------------------------------------------------------------===// 122 // Positions 123 //===----------------------------------------------------------------------===// 124 125 struct OperationPosition; 126 127 /// A position describes a value on the input IR on which a predicate may be 128 /// applied, such as an operation or attribute. This enables re-use between 129 /// predicates, and assists generating bytecode and memory management. 130 /// 131 /// Operation positions form the base of other positions, which are formed 132 /// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. 
Operations 133 /// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd 134 /// child of the root operation. 135 /// 136 /// Positions are linked to their parent position, which describes how to obtain 137 /// a positional value. As a concrete example, getting OperationPosition<[0, 1]> 138 /// would be `root->getOperand(1)->getDefiningOp()`, so its parent is 139 /// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>. 140 class Position : public StorageUniquer::BaseStorage { 141 public: Position(Predicates::Kind kind)142 explicit Position(Predicates::Kind kind) : kind(kind) {} 143 virtual ~Position(); 144 145 /// Returns the base node position. This is an array of indices. 146 virtual ArrayRef<unsigned> getIndex() const = 0; 147 148 /// Returns the parent position. The root operation position has no parent. getParent()149 Position *getParent() const { return parent; } 150 151 /// Returns the kind of this position. getKind()152 Predicates::Kind getKind() const { return kind; } 153 154 protected: 155 /// Link to the parent position. 156 Position *parent = nullptr; 157 158 private: 159 /// The kind of this position. 160 Predicates::Kind kind; 161 }; 162 163 //===----------------------------------------------------------------------===// 164 // AttributePosition 165 166 /// A position describing an attribute of an operation. 167 struct AttributePosition 168 : public PredicateBase<AttributePosition, Position, 169 std::pair<OperationPosition *, Identifier>, 170 Predicates::AttributePos> { 171 explicit AttributePosition(const KeyTy &key); 172 173 /// Returns the index of this position. getIndexAttributePosition174 ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); } 175 176 /// Returns the attribute name of this position. 
getNameAttributePosition177 Identifier getName() const { return key.second; } 178 }; 179 180 //===----------------------------------------------------------------------===// 181 // OperandPosition 182 183 /// A position describing an operand of an operation. 184 struct OperandPosition 185 : public PredicateBase<OperandPosition, Position, 186 std::pair<OperationPosition *, unsigned>, 187 Predicates::OperandPos> { 188 explicit OperandPosition(const KeyTy &key); 189 190 /// Returns the index of this position. getIndexOperandPosition191 ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); } 192 193 /// Returns the operand number of this position. getOperandNumberOperandPosition194 unsigned getOperandNumber() const { return key.second; } 195 }; 196 197 //===----------------------------------------------------------------------===// 198 // OperationPosition 199 200 /// An operation position describes an operation node in the IR. Other position 201 /// kinds are formed with respect to an operation position. 202 struct OperationPosition 203 : public PredicateBase<OperationPosition, Position, ArrayRef<unsigned>, 204 Predicates::OperationPos> { 205 using Base::Base; 206 207 /// Gets the root position, which is always [0]. getRootOperationPosition208 static OperationPosition *getRoot(StorageUniquer &uniquer) { 209 return get(uniquer, ArrayRef<unsigned>(0)); 210 } 211 /// Gets a node position for the given index. 212 static OperationPosition *get(StorageUniquer &uniquer, 213 ArrayRef<unsigned> index); 214 215 /// Constructs an instance with the given storage allocator. constructOperationPosition216 static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc, 217 ArrayRef<unsigned> key) { 218 return Base::construct(alloc, alloc.copyInto(key)); 219 } 220 221 /// Returns the index of this position. 
getIndexOperationPosition222 ArrayRef<unsigned> getIndex() const final { return key; } 223 224 /// Returns if this operation position corresponds to the root. isRootOperationPosition225 bool isRoot() const { return key.size() == 1 && key[0] == 0; } 226 }; 227 228 //===----------------------------------------------------------------------===// 229 // ResultPosition 230 231 /// A position describing a result of an operation. 232 struct ResultPosition 233 : public PredicateBase<ResultPosition, Position, 234 std::pair<OperationPosition *, unsigned>, 235 Predicates::ResultPos> { ResultPositionResultPosition236 explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } 237 238 /// Returns the index of this position. getIndexResultPosition239 ArrayRef<unsigned> getIndex() const final { return key.first->getIndex(); } 240 241 /// Returns the result number of this position. getResultNumberResultPosition242 unsigned getResultNumber() const { return key.second; } 243 }; 244 245 //===----------------------------------------------------------------------===// 246 // TypePosition 247 248 /// A position describing the result type of an entity, i.e. an Attribute, 249 /// Operand, Result, etc. 250 struct TypePosition : public PredicateBase<TypePosition, Position, Position *, 251 Predicates::TypePos> { TypePositionTypePosition252 explicit TypePosition(const KeyTy &key) : Base(key) { 253 assert((isa<AttributePosition>(key) || isa<OperandPosition>(key) || 254 isa<ResultPosition>(key)) && 255 "expected parent to be an attribute, operand, or result"); 256 parent = key; 257 } 258 259 /// Returns the index of this position. 
getIndexTypePosition260 ArrayRef<unsigned> getIndex() const final { return key->getIndex(); } 261 }; 262 263 //===----------------------------------------------------------------------===// 264 // Qualifiers 265 //===----------------------------------------------------------------------===// 266 267 /// An ordinal predicate consists of a "Question" and a set of acceptable 268 /// "Answers" (later converted to ordinal values). A predicate will query some 269 /// property of a positional value and decide what to do based on the result. 270 /// 271 /// This makes top-level predicate representations ordinal (SwitchOp). Later, 272 /// predicates that end up with only one acceptable answer (including all 273 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the 274 /// matcher. 275 /// 276 /// For simplicity, both are represented as "qualifiers", with a base kind and 277 /// perhaps additional properties. For example, all OperationName predicates ask 278 /// the same question, but GenericConstraint predicates may ask different ones. 279 class Qualifier : public StorageUniquer::BaseStorage { 280 public: Qualifier(Predicates::Kind kind)281 explicit Qualifier(Predicates::Kind kind) : kind(kind) {} 282 283 /// Returns the kind of this qualifier. getKind()284 Predicates::Kind getKind() const { return kind; } 285 286 private: 287 /// The kind of this position. 288 Predicates::Kind kind; 289 }; 290 291 //===----------------------------------------------------------------------===// 292 // Answers 293 294 /// An Answer representing an `Attribute` value. 295 struct AttributeAnswer 296 : public PredicateBase<AttributeAnswer, Qualifier, Attribute, 297 Predicates::AttributeAnswer> { 298 using Base::Base; 299 }; 300 301 /// An Answer representing an `OperationName` value. 
302 struct OperationNameAnswer 303 : public PredicateBase<OperationNameAnswer, Qualifier, OperationName, 304 Predicates::OperationNameAnswer> { 305 using Base::Base; 306 }; 307 308 /// An Answer representing a boolean `true` value. 309 struct TrueAnswer 310 : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> { 311 using Base::Base; 312 }; 313 314 /// An Answer representing a `Type` value. 315 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Type, 316 Predicates::TypeAnswer> { 317 using Base::Base; 318 }; 319 320 /// An Answer representing an unsigned value. 321 struct UnsignedAnswer 322 : public PredicateBase<UnsignedAnswer, Qualifier, unsigned, 323 Predicates::UnsignedAnswer> { 324 using Base::Base; 325 }; 326 327 //===----------------------------------------------------------------------===// 328 // Questions 329 330 /// Compare an `Attribute` to a constant value. 331 struct AttributeQuestion 332 : public PredicateBase<AttributeQuestion, Qualifier, void, 333 Predicates::AttributeQuestion> {}; 334 335 /// Apply a parameterized constraint to multiple position values. 336 struct ConstraintQuestion 337 : public PredicateBase< 338 ConstraintQuestion, Qualifier, 339 std::tuple<StringRef, ArrayRef<Position *>, Attribute>, 340 Predicates::ConstraintQuestion> { 341 using Base::Base; 342 343 /// Construct an instance with the given storage allocator. constructConstraintQuestion344 static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, 345 KeyTy key) { 346 return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), 347 alloc.copyInto(std::get<1>(key)), 348 std::get<2>(key)}); 349 } 350 }; 351 352 /// Compare the equality of two values. 353 struct EqualToQuestion 354 : public PredicateBase<EqualToQuestion, Qualifier, Position *, 355 Predicates::EqualToQuestion> { 356 using Base::Base; 357 }; 358 359 /// Compare a positional value with null, i.e. check if it exists. 
360 struct IsNotNullQuestion 361 : public PredicateBase<IsNotNullQuestion, Qualifier, void, 362 Predicates::IsNotNullQuestion> {}; 363 364 /// Compare the number of operands of an operation with a known value. 365 struct OperandCountQuestion 366 : public PredicateBase<OperandCountQuestion, Qualifier, void, 367 Predicates::OperandCountQuestion> {}; 368 369 /// Compare the name of an operation with a known value. 370 struct OperationNameQuestion 371 : public PredicateBase<OperationNameQuestion, Qualifier, void, 372 Predicates::OperationNameQuestion> {}; 373 374 /// Compare the number of results of an operation with a known value. 375 struct ResultCountQuestion 376 : public PredicateBase<ResultCountQuestion, Qualifier, void, 377 Predicates::ResultCountQuestion> {}; 378 379 /// Compare the type of an attribute or value with a known type. 380 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void, 381 Predicates::TypeQuestion> {}; 382 383 //===----------------------------------------------------------------------===// 384 // PredicateUniquer 385 //===----------------------------------------------------------------------===// 386 387 /// This class provides a storage uniquer that is used to allocate predicate 388 /// instances. 389 class PredicateUniquer : public StorageUniquer { 390 public: PredicateUniquer()391 PredicateUniquer() { 392 // Register the types of Positions with the uniquer. 393 registerParametricStorageType<AttributePosition>(); 394 registerParametricStorageType<OperandPosition>(); 395 registerParametricStorageType<OperationPosition>(); 396 registerParametricStorageType<ResultPosition>(); 397 registerParametricStorageType<TypePosition>(); 398 399 // Register the types of Questions with the uniquer. 
400 registerParametricStorageType<AttributeAnswer>(); 401 registerParametricStorageType<OperationNameAnswer>(); 402 registerParametricStorageType<TypeAnswer>(); 403 registerParametricStorageType<UnsignedAnswer>(); 404 registerSingletonStorageType<TrueAnswer>(); 405 406 // Register the types of Answers with the uniquer. 407 registerParametricStorageType<ConstraintQuestion>(); 408 registerParametricStorageType<EqualToQuestion>(); 409 registerSingletonStorageType<AttributeQuestion>(); 410 registerSingletonStorageType<IsNotNullQuestion>(); 411 registerSingletonStorageType<OperandCountQuestion>(); 412 registerSingletonStorageType<OperationNameQuestion>(); 413 registerSingletonStorageType<ResultCountQuestion>(); 414 registerSingletonStorageType<TypeQuestion>(); 415 } 416 }; 417 418 //===----------------------------------------------------------------------===// 419 // PredicateBuilder 420 //===----------------------------------------------------------------------===// 421 422 /// This class provides utilties for constructing predicates. 423 class PredicateBuilder { 424 public: PredicateBuilder(PredicateUniquer & uniquer,MLIRContext * ctx)425 PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) 426 : uniquer(uniquer), ctx(ctx) {} 427 428 //===--------------------------------------------------------------------===// 429 // Positions 430 //===--------------------------------------------------------------------===// 431 432 /// Returns the root operation position. getRoot()433 Position *getRoot() { return OperationPosition::getRoot(uniquer); } 434 435 /// Returns the parent position defining the value held by the given operand. getParent(OperandPosition * p)436 Position *getParent(OperandPosition *p) { 437 std::vector<unsigned> index = p->getIndex(); 438 index.push_back(p->getOperandNumber()); 439 return OperationPosition::get(uniquer, index); 440 } 441 442 /// Returns an attribute position for an attribute of the given operation. 
getAttribute(OperationPosition * p,StringRef name)443 Position *getAttribute(OperationPosition *p, StringRef name) { 444 return AttributePosition::get(uniquer, p, Identifier::get(name, ctx)); 445 } 446 447 /// Returns an operand position for an operand of the given operation. getOperand(OperationPosition * p,unsigned operand)448 Position *getOperand(OperationPosition *p, unsigned operand) { 449 return OperandPosition::get(uniquer, p, operand); 450 } 451 452 /// Returns a result position for a result of the given operation. getResult(OperationPosition * p,unsigned result)453 Position *getResult(OperationPosition *p, unsigned result) { 454 return ResultPosition::get(uniquer, p, result); 455 } 456 457 /// Returns a type position for the given entity. getType(Position * p)458 Position *getType(Position *p) { return TypePosition::get(uniquer, p); } 459 460 //===--------------------------------------------------------------------===// 461 // Qualifiers 462 //===--------------------------------------------------------------------===// 463 464 /// An ordinal predicate consists of a "Question" and a set of acceptable 465 /// "Answers" (later converted to ordinal values). A predicate will query some 466 /// property of a positional value and decide what to do based on the result. 467 using Predicate = std::pair<Qualifier *, Qualifier *>; 468 469 /// Create a predicate comparing an attribute to a known value. getAttributeConstraint(Attribute attr)470 Predicate getAttributeConstraint(Attribute attr) { 471 return {AttributeQuestion::get(uniquer), 472 AttributeAnswer::get(uniquer, attr)}; 473 } 474 475 /// Create a predicate comparing two values. getEqualTo(Position * pos)476 Predicate getEqualTo(Position *pos) { 477 return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; 478 } 479 480 /// Create a predicate that applies a generic constraint. 
getConstraint(StringRef name,ArrayRef<Position * > pos,Attribute params)481 Predicate getConstraint(StringRef name, ArrayRef<Position *> pos, 482 Attribute params) { 483 return { 484 ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)), 485 TrueAnswer::get(uniquer)}; 486 } 487 488 /// Create a predicate comparing a value with null. getIsNotNull()489 Predicate getIsNotNull() { 490 return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; 491 } 492 493 /// Create a predicate comparing the number of operands of an operation to a 494 /// known value. getOperandCount(unsigned count)495 Predicate getOperandCount(unsigned count) { 496 return {OperandCountQuestion::get(uniquer), 497 UnsignedAnswer::get(uniquer, count)}; 498 } 499 500 /// Create a predicate comparing the name of an operation to a known value. getOperationName(StringRef name)501 Predicate getOperationName(StringRef name) { 502 return {OperationNameQuestion::get(uniquer), 503 OperationNameAnswer::get(uniquer, OperationName(name, ctx))}; 504 } 505 506 /// Create a predicate comparing the number of results of an operation to a 507 /// known value. getResultCount(unsigned count)508 Predicate getResultCount(unsigned count) { 509 return {ResultCountQuestion::get(uniquer), 510 UnsignedAnswer::get(uniquer, count)}; 511 } 512 513 /// Create a predicate comparing the type of an attribute or value to a known 514 /// type. getTypeConstraint(Type type)515 Predicate getTypeConstraint(Type type) { 516 return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)}; 517 } 518 519 private: 520 /// The uniquer used when allocating predicate nodes. 521 PredicateUniquer &uniquer; 522 523 /// The current MLIR context. 524 MLIRContext *ctx; 525 }; 526 527 } // end namespace pdl_to_pdl_interp 528 } // end namespace mlir 529 530 #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 531