1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares a generic pass for converting between MLIR dialects. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_ 14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ 15 16 #include "mlir/Rewrite/FrozenRewritePatternList.h" 17 #include "llvm/ADT/MapVector.h" 18 #include "llvm/ADT/StringMap.h" 19 20 namespace mlir { 21 22 // Forward declarations. 23 class Block; 24 class ConversionPatternRewriter; 25 class FuncOp; 26 class MLIRContext; 27 class Operation; 28 class Type; 29 class Value; 30 31 //===----------------------------------------------------------------------===// 32 // Type Conversion 33 //===----------------------------------------------------------------------===// 34 35 /// Type conversion class. Specific conversions and materializations can be 36 /// registered using addConversion and addMaterialization, respectively. 37 class TypeConverter { 38 public: 39 /// This class provides all of the information necessary to convert a type 40 /// signature. 41 class SignatureConversion { 42 public: SignatureConversion(unsigned numOrigInputs)43 SignatureConversion(unsigned numOrigInputs) 44 : remappedInputs(numOrigInputs) {} 45 46 /// This struct represents a range of new types or a single value that 47 /// remaps an existing signature input. 48 struct InputMapping { 49 size_t inputNo, size; 50 Value replacementValue; 51 }; 52 53 /// Return the argument types for the new signature. getConvertedTypes()54 ArrayRef<Type> getConvertedTypes() const { return argTypes; } 55 56 /// Get the input mapping for the given argument. getInputMapping(unsigned input)57 Optional<InputMapping> getInputMapping(unsigned input) const { 58 return remappedInputs[input]; 59 } 60 61 //===------------------------------------------------------------------===// 62 // Conversion Hooks 63 //===------------------------------------------------------------------===// 64 65 /// Remap an input of the original signature with a new set of types. The 66 /// new types are appended to the new signature conversion. 67 void addInputs(unsigned origInputNo, ArrayRef<Type> types); 68 69 /// Append new input types to the signature conversion, this should only be 70 /// used if the new types are not intended to remap an existing input. 71 void addInputs(ArrayRef<Type> types); 72 73 /// Remap an input of the original signature to another `replacement` 74 /// value. This drops the original argument. 75 void remapInput(unsigned origInputNo, Value replacement); 76 77 private: 78 /// Remap an input of the original signature with a range of types in the 79 /// new signature. 80 void remapInput(unsigned origInputNo, unsigned newInputNo, 81 unsigned newInputCount = 1); 82 83 /// The remapping information for each of the original arguments. 84 SmallVector<Optional<InputMapping>, 4> remappedInputs; 85 86 /// The set of new argument types. 87 SmallVector<Type, 4> argTypes; 88 }; 89 90 /// Register a conversion function. A conversion function must be convertible 91 /// to any of the following forms(where `T` is a class derived from `Type`: 92 /// * Optional<Type>(T) 93 /// - This form represents a 1-1 type conversion. It should return nullptr 94 /// or `llvm::None` to signify failure. If `llvm::None` is returned, the 95 /// converter is allowed to try another conversion function to perform 96 /// the conversion. 97 /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &) 98 /// - This form represents a 1-N type conversion. It should return 99 /// `failure` or `llvm::None` to signify a failed conversion. If the new 100 /// set of types is empty, the type is removed and any usages of the 101 /// existing value are expected to be removed during conversion. If 102 /// `llvm::None` is returned, the converter is allowed to try another 103 /// conversion function to perform the conversion. 104 /// Note: When attempting to convert a type, e.g. via 'convertType', the 105 /// mostly recently added conversions will be invoked first. 106 template <typename FnT, 107 typename T = typename llvm::function_traits<FnT>::template arg_t<0>> addConversion(FnT && callback)108 void addConversion(FnT &&callback) { 109 registerConversion(wrapCallback<T>(std::forward<FnT>(callback))); 110 } 111 112 /// Register a materialization function, which must be convertible to the 113 /// following form: 114 /// `Optional<Value>(OpBuilder &, T, ValueRange, Location)`, 115 /// where `T` is any subclass of `Type`. This function is responsible for 116 /// creating an operation, using the OpBuilder and Location provided, that 117 /// "casts" a range of values into a single value of the given type `T`. It 118 /// must return a Value of the converted type on success, an `llvm::None` if 119 /// it failed but other materialization can be attempted, and `nullptr` on 120 /// unrecoverable failure. It will only be called for (sub)types of `T`. 121 /// Materialization functions must be provided when a type conversion 122 /// results in more than one type, or if a type conversion may persist after 123 /// the conversion has finished. 124 /// 125 /// This method registers a materialization that will be called when 126 /// converting an illegal block argument type, to a legal type. 127 template <typename FnT, 128 typename T = typename llvm::function_traits<FnT>::template arg_t<1>> addArgumentMaterialization(FnT && callback)129 void addArgumentMaterialization(FnT &&callback) { 130 argumentMaterializations.emplace_back( 131 wrapMaterialization<T>(std::forward<FnT>(callback))); 132 } 133 /// This method registers a materialization that will be called when 134 /// converting a legal type to an illegal source type. This is used when 135 /// conversions to an illegal type must persist beyond the main conversion. 136 template <typename FnT, 137 typename T = typename llvm::function_traits<FnT>::template arg_t<1>> addSourceMaterialization(FnT && callback)138 void addSourceMaterialization(FnT &&callback) { 139 sourceMaterializations.emplace_back( 140 wrapMaterialization<T>(std::forward<FnT>(callback))); 141 } 142 /// This method registers a materialization that will be called when 143 /// converting type from an illegal, or source, type to a legal type. 144 template <typename FnT, 145 typename T = typename llvm::function_traits<FnT>::template arg_t<1>> addTargetMaterialization(FnT && callback)146 void addTargetMaterialization(FnT &&callback) { 147 targetMaterializations.emplace_back( 148 wrapMaterialization<T>(std::forward<FnT>(callback))); 149 } 150 151 /// Convert the given type. This function should return failure if no valid 152 /// conversion exists, success otherwise. If the new set of types is empty, 153 /// the type is removed and any usages of the existing value are expected to 154 /// be removed during conversion. 155 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results); 156 157 /// This hook simplifies defining 1-1 type conversions. This function returns 158 /// the type to convert to on success, and a null type on failure. 159 Type convertType(Type t); 160 161 /// Convert the given set of types, filling 'results' as necessary. This 162 /// returns failure if the conversion of any of the types fails, success 163 /// otherwise. 164 LogicalResult convertTypes(ArrayRef<Type> types, 165 SmallVectorImpl<Type> &results); 166 167 /// Return true if the given type is legal for this type converter, i.e. the 168 /// type converts to itself. 169 bool isLegal(Type type); 170 /// Return true if all of the given types are legal for this type converter. 171 template <typename RangeT> 172 std::enable_if_t<!std::is_convertible<RangeT, Type>::value && 173 !std::is_convertible<RangeT, Operation *>::value, 174 bool> isLegal(RangeT && range)175 isLegal(RangeT &&range) { 176 return llvm::all_of(range, [this](Type type) { return isLegal(type); }); 177 } 178 /// Return true if the given operation has legal operand and result types. 179 bool isLegal(Operation *op); 180 181 /// Return true if the types of block arguments within the region are legal. 182 bool isLegal(Region *region); 183 184 /// Return true if the inputs and outputs of the given function type are 185 /// legal. 186 bool isSignatureLegal(FunctionType ty); 187 188 /// This method allows for converting a specific argument of a signature. It 189 /// takes as inputs the original argument input number, type. 190 /// On success, it populates 'result' with any new mappings. 191 LogicalResult convertSignatureArg(unsigned inputNo, Type type, 192 SignatureConversion &result); 193 LogicalResult convertSignatureArgs(TypeRange types, 194 SignatureConversion &result, 195 unsigned origInputOffset = 0); 196 197 /// This function converts the type signature of the given block, by invoking 198 /// 'convertSignatureArg' for each argument. This function should return a 199 /// valid conversion for the signature on success, None otherwise. 200 Optional<SignatureConversion> convertBlockSignature(Block *block); 201 202 /// Materialize a conversion from a set of types into one result type by 203 /// generating a cast sequence of some kind. See the respective 204 /// `add*Materialization` for more information on the context for these 205 /// methods. materializeArgumentConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)206 Value materializeArgumentConversion(OpBuilder &builder, Location loc, 207 Type resultType, ValueRange inputs) { 208 return materializeConversion(argumentMaterializations, builder, loc, 209 resultType, inputs); 210 } materializeSourceConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)211 Value materializeSourceConversion(OpBuilder &builder, Location loc, 212 Type resultType, ValueRange inputs) { 213 return materializeConversion(sourceMaterializations, builder, loc, 214 resultType, inputs); 215 } materializeTargetConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)216 Value materializeTargetConversion(OpBuilder &builder, Location loc, 217 Type resultType, ValueRange inputs) { 218 return materializeConversion(targetMaterializations, builder, loc, 219 resultType, inputs); 220 } 221 222 private: 223 /// The signature of the callback used to convert a type. If the new set of 224 /// types is empty, the type is removed and any usages of the existing value 225 /// are expected to be removed during conversion. 226 using ConversionCallbackFn = 227 std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>; 228 229 /// The signature of the callback used to materialize a conversion. 230 using MaterializationCallbackFn = 231 std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>; 232 233 /// Attempt to materialize a conversion using one of the provided 234 /// materialization functions. 235 Value materializeConversion( 236 MutableArrayRef<MaterializationCallbackFn> materializations, 237 OpBuilder &builder, Location loc, Type resultType, ValueRange inputs); 238 239 /// Generate a wrapper for the given callback. This allows for accepting 240 /// different callback forms, that all compose into a single version. 241 /// With callback of form: `Optional<Type>(T)` 242 template <typename T, typename FnT> 243 std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn> wrapCallback(FnT && callback)244 wrapCallback(FnT &&callback) { 245 return wrapCallback<T>([callback = std::forward<FnT>(callback)]( 246 T type, SmallVectorImpl<Type> &results) { 247 if (Optional<Type> resultOpt = callback(type)) { 248 bool wasSuccess = static_cast<bool>(resultOpt.getValue()); 249 if (wasSuccess) 250 results.push_back(resultOpt.getValue()); 251 return Optional<LogicalResult>(success(wasSuccess)); 252 } 253 return Optional<LogicalResult>(); 254 }); 255 } 256 /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)` 257 template <typename T, typename FnT> 258 std::enable_if_t<!llvm::is_invocable<FnT, T>::value, ConversionCallbackFn> wrapCallback(FnT && callback)259 wrapCallback(FnT &&callback) { 260 return [callback = std::forward<FnT>(callback)]( 261 Type type, 262 SmallVectorImpl<Type> &results) -> Optional<LogicalResult> { 263 T derivedType = type.dyn_cast<T>(); 264 if (!derivedType) 265 return llvm::None; 266 return callback(derivedType, results); 267 }; 268 } 269 270 /// Register a type conversion. registerConversion(ConversionCallbackFn callback)271 void registerConversion(ConversionCallbackFn callback) { 272 conversions.emplace_back(std::move(callback)); 273 cachedDirectConversions.clear(); 274 cachedMultiConversions.clear(); 275 } 276 277 /// Generate a wrapper for the given materialization callback. The callback 278 /// may take any subclass of `Type` and the wrapper will check for the target 279 /// type to be of the expected class before calling the callback. 280 template <typename T, typename FnT> wrapMaterialization(FnT && callback)281 MaterializationCallbackFn wrapMaterialization(FnT &&callback) { 282 return [callback = std::forward<FnT>(callback)]( 283 OpBuilder &builder, Type resultType, ValueRange inputs, 284 Location loc) -> Optional<Value> { 285 if (T derivedType = resultType.dyn_cast<T>()) 286 return callback(builder, derivedType, inputs, loc); 287 return llvm::None; 288 }; 289 } 290 291 /// The set of registered conversion functions. 292 SmallVector<ConversionCallbackFn, 4> conversions; 293 294 /// The list of registered materialization functions. 295 SmallVector<MaterializationCallbackFn, 2> argumentMaterializations; 296 SmallVector<MaterializationCallbackFn, 2> sourceMaterializations; 297 SmallVector<MaterializationCallbackFn, 2> targetMaterializations; 298 299 /// A set of cached conversions to avoid recomputing in the common case. 300 /// Direct 1-1 conversions are the most common, so this cache stores the 301 /// successful 1-1 conversions as well as all failed conversions. 302 DenseMap<Type, Type> cachedDirectConversions; 303 /// This cache stores the successful 1->N conversions, where N != 1. 304 DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions; 305 }; 306 307 //===----------------------------------------------------------------------===// 308 // Conversion Patterns 309 //===----------------------------------------------------------------------===// 310 311 /// Base class for the conversion patterns. This pattern class enables type 312 /// conversions, and other uses specific to the conversion framework. As such, 313 /// patterns of this type can only be used with the 'apply*' methods below. 314 class ConversionPattern : public RewritePattern { 315 public: 316 /// Hook for derived classes to implement rewriting. `op` is the (first) 317 /// operation matched by the pattern, `operands` is a list of the rewritten 318 /// operand values that are passed to `op`, `rewriter` can be used to emit the 319 /// new operations. This function should not fail. If some specific cases of 320 /// the operation are not supported, these cases should not be matched. rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)321 virtual void rewrite(Operation *op, ArrayRef<Value> operands, 322 ConversionPatternRewriter &rewriter) const { 323 llvm_unreachable("unimplemented rewrite"); 324 } 325 326 /// Hook for derived classes to implement combined matching and rewriting. 327 virtual LogicalResult matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)328 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 329 ConversionPatternRewriter &rewriter) const { 330 if (failed(match(op))) 331 return failure(); 332 rewrite(op, operands, rewriter); 333 return success(); 334 } 335 336 /// Attempt to match and rewrite the IR root at the specified operation. 337 LogicalResult matchAndRewrite(Operation *op, 338 PatternRewriter &rewriter) const final; 339 340 /// Return the type converter held by this pattern, or nullptr if the pattern 341 /// does not require type conversion. getTypeConverter()342 TypeConverter *getTypeConverter() const { return typeConverter; } 343 344 protected: 345 /// See `RewritePattern::RewritePattern` for information on the other 346 /// available constructors. 347 using RewritePattern::RewritePattern; 348 /// Construct a conversion pattern that matches an operation with the given 349 /// root name. This constructor allows for providing a type converter to use 350 /// within the pattern. ConversionPattern(StringRef rootName,PatternBenefit benefit,TypeConverter & typeConverter,MLIRContext * ctx)351 ConversionPattern(StringRef rootName, PatternBenefit benefit, 352 TypeConverter &typeConverter, MLIRContext *ctx) 353 : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} 354 /// Construct a conversion pattern that matches any operation type. This 355 /// constructor allows for providing a type converter to use within the 356 /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" 357 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should 358 /// always be supplied here. ConversionPattern(PatternBenefit benefit,TypeConverter & typeConverter,MatchAnyOpTypeTag tag)359 ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, 360 MatchAnyOpTypeTag tag) 361 : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} 362 363 protected: 364 /// An optional type converter for use by this pattern. 365 TypeConverter *typeConverter = nullptr; 366 367 private: 368 using RewritePattern::rewrite; 369 }; 370 371 /// OpConversionPattern is a wrapper around ConversionPattern that allows for 372 /// matching and rewriting against an instance of a derived operation class as 373 /// opposed to a raw Operation. 374 template <typename SourceOp> 375 struct OpConversionPattern : public ConversionPattern { 376 OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) ConversionPatternOpConversionPattern377 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} 378 OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, 379 PatternBenefit benefit = 1) ConversionPatternOpConversionPattern380 : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter, 381 context) {} 382 383 /// Wrappers around the ConversionPattern methods that pass the derived op 384 /// type. rewriteOpConversionPattern385 void rewrite(Operation *op, ArrayRef<Value> operands, 386 ConversionPatternRewriter &rewriter) const final { 387 rewrite(cast<SourceOp>(op), operands, rewriter); 388 } 389 LogicalResult matchAndRewriteOpConversionPattern390 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 391 ConversionPatternRewriter &rewriter) const final { 392 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); 393 } 394 395 // TODO: Use OperandAdaptor when it supports access to unnamed operands. 396 397 /// Rewrite and Match methods that operate on the SourceOp type. These must be 398 /// overridden by the derived pattern class. rewriteOpConversionPattern399 virtual void rewrite(SourceOp op, ArrayRef<Value> operands, 400 ConversionPatternRewriter &rewriter) const { 401 llvm_unreachable("must override matchAndRewrite or a rewrite method"); 402 } 403 404 virtual LogicalResult matchAndRewriteOpConversionPattern405 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 406 ConversionPatternRewriter &rewriter) const { 407 if (failed(match(op))) 408 return failure(); 409 rewrite(op, operands, rewriter); 410 return success(); 411 } 412 413 private: 414 using ConversionPattern::matchAndRewrite; 415 }; 416 417 /// Add a pattern to the given pattern list to convert the signature of a FuncOp 418 /// with the given type converter. 419 void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns, 420 MLIRContext *ctx, 421 TypeConverter &converter); 422 423 //===----------------------------------------------------------------------===// 424 // Conversion PatternRewriter 425 //===----------------------------------------------------------------------===// 426 427 namespace detail { 428 struct ConversionPatternRewriterImpl; 429 } // end namespace detail 430 431 /// This class implements a pattern rewriter for use with ConversionPatterns. It 432 /// extends the base PatternRewriter and provides special conversion specific 433 /// hooks. 434 class ConversionPatternRewriter final : public PatternRewriter { 435 public: 436 ConversionPatternRewriter(MLIRContext *ctx); 437 ~ConversionPatternRewriter() override; 438 439 /// Apply a signature conversion to the entry block of the given region. This 440 /// replaces the entry block with a new block containing the updated 441 /// signature. The new entry block to the region is returned for convenience. 442 Block * 443 applySignatureConversion(Region *region, 444 TypeConverter::SignatureConversion &conversion); 445 446 /// Convert the types of block arguments within the given region. This 447 /// replaces each block with a new block containing the updated signature. The 448 /// entry block may have a special conversion if `entryConversion` is 449 /// provided. On success, the new entry block to the region is returned for 450 /// convenience. Otherwise, failure is returned. 451 FailureOr<Block *> convertRegionTypes( 452 Region *region, TypeConverter &converter, 453 TypeConverter::SignatureConversion *entryConversion = nullptr); 454 455 /// Replace all the uses of the block argument `from` with value `to`. 456 void replaceUsesOfBlockArgument(BlockArgument from, Value to); 457 458 /// Return the converted value that replaces 'key'. Return 'key' if there is 459 /// no such a converted value. 460 Value getRemappedValue(Value key); 461 462 //===--------------------------------------------------------------------===// 463 // PatternRewriter Hooks 464 //===--------------------------------------------------------------------===// 465 466 /// PatternRewriter hook for replacing the results of an operation. 467 void replaceOp(Operation *op, ValueRange newValues) override; 468 using PatternRewriter::replaceOp; 469 470 /// PatternRewriter hook for erasing a dead operation. The uses of this 471 /// operation *must* be made dead by the end of the conversion process, 472 /// otherwise an assert will be issued. 473 void eraseOp(Operation *op) override; 474 475 /// PatternRewriter hook for erase all operations in a block. This is not yet 476 /// implemented for dialect conversion. 477 void eraseBlock(Block *block) override; 478 479 /// PatternRewriter hook creating a new block. 480 void notifyBlockCreated(Block *block) override; 481 482 /// PatternRewriter hook for splitting a block into two parts. 483 Block *splitBlock(Block *block, Block::iterator before) override; 484 485 /// PatternRewriter hook for merging a block into another. 486 void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override; 487 488 /// PatternRewriter hook for moving blocks out of a region. 489 void inlineRegionBefore(Region ®ion, Region &parent, 490 Region::iterator before) override; 491 using PatternRewriter::inlineRegionBefore; 492 493 /// PatternRewriter hook for cloning blocks of one region into another. The 494 /// given region to clone *must* not have been modified as part of conversion 495 /// yet, i.e. it must be within an operation that is either in the process of 496 /// conversion, or has not yet been converted. 497 void cloneRegionBefore(Region ®ion, Region &parent, 498 Region::iterator before, 499 BlockAndValueMapping &mapping) override; 500 using PatternRewriter::cloneRegionBefore; 501 502 /// PatternRewriter hook for inserting a new operation. 503 void notifyOperationInserted(Operation *op) override; 504 505 /// PatternRewriter hook for updating the root operation in-place. 506 /// Note: These methods only track updates to the top-level operation itself, 507 /// and not nested regions. Updates to regions will still require notification 508 /// through other more specific hooks above. 509 void startRootUpdate(Operation *op) override; 510 511 /// PatternRewriter hook for updating the root operation in-place. 512 void finalizeRootUpdate(Operation *op) override; 513 514 /// PatternRewriter hook for updating the root operation in-place. 515 void cancelRootUpdate(Operation *op) override; 516 517 /// PatternRewriter hook for notifying match failure reasons. 518 LogicalResult 519 notifyMatchFailure(Operation *op, 520 function_ref<void(Diagnostic &)> reasonCallback) override; 521 using PatternRewriter::notifyMatchFailure; 522 523 /// Return a reference to the internal implementation. 524 detail::ConversionPatternRewriterImpl &getImpl(); 525 526 private: 527 std::unique_ptr<detail::ConversionPatternRewriterImpl> impl; 528 }; 529 530 //===----------------------------------------------------------------------===// 531 // ConversionTarget 532 //===----------------------------------------------------------------------===// 533 534 /// This class describes a specific conversion target. 535 class ConversionTarget { 536 public: 537 /// This enumeration corresponds to the specific action to take when 538 /// considering an operation legal for this conversion target. 539 enum class LegalizationAction { 540 /// The target supports this operation. 541 Legal, 542 543 /// This operation has dynamic legalization constraints that must be checked 544 /// by the target. 545 Dynamic, 546 547 /// The target explicitly does not support this operation. 548 Illegal, 549 }; 550 551 /// A structure containing additional information describing a specific legal 552 /// operation instance. 553 struct LegalOpDetails { 554 /// A flag that indicates if this operation is 'recursively' legal. This 555 /// means that if an operation is legal, either statically or dynamically, 556 /// all of the operations nested within are also considered legal. 557 bool isRecursivelyLegal = false; 558 }; 559 560 /// The signature of the callback used to determine if an operation is 561 /// dynamically legal on the target. 562 using DynamicLegalityCallbackFn = std::function<bool(Operation *)>; 563 ConversionTarget(MLIRContext & ctx)564 ConversionTarget(MLIRContext &ctx) 565 : unknownOpsDynamicallyLegal(false), ctx(ctx) {} 566 virtual ~ConversionTarget() = default; 567 568 //===--------------------------------------------------------------------===// 569 // Legality Registration 570 //===--------------------------------------------------------------------===// 571 572 /// Register a legality action for the given operation. 573 void setOpAction(OperationName op, LegalizationAction action); setOpAction(LegalizationAction action)574 template <typename OpT> void setOpAction(LegalizationAction action) { 575 setOpAction(OperationName(OpT::getOperationName(), &ctx), action); 576 } 577 578 /// Register the given operations as legal. addLegalOp()579 template <typename OpT> void addLegalOp() { 580 setOpAction<OpT>(LegalizationAction::Legal); 581 } addLegalOp()582 template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() { 583 addLegalOp<OpT>(); 584 addLegalOp<OpT2, OpTs...>(); 585 } 586 587 /// Register the given operation as dynamically legal, i.e. requiring custom 588 /// handling by the target via 'isDynamicallyLegal'. addDynamicallyLegalOp()589 template <typename OpT> void addDynamicallyLegalOp() { 590 setOpAction<OpT>(LegalizationAction::Dynamic); 591 } 592 template <typename OpT, typename OpT2, typename... OpTs> addDynamicallyLegalOp()593 void addDynamicallyLegalOp() { 594 addDynamicallyLegalOp<OpT>(); 595 addDynamicallyLegalOp<OpT2, OpTs...>(); 596 } 597 598 /// Register the given operation as dynamically legal and set the dynamic 599 /// legalization callback to the one provided. 600 template <typename OpT> addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)601 void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) { 602 OperationName opName(OpT::getOperationName(), &ctx); 603 setOpAction(opName, LegalizationAction::Dynamic); 604 setLegalityCallback(opName, callback); 605 } 606 template <typename OpT, typename OpT2, typename... OpTs> addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)607 void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) { 608 addDynamicallyLegalOp<OpT>(callback); 609 addDynamicallyLegalOp<OpT2, OpTs...>(callback); 610 } 611 template <typename OpT, class Callable> 612 typename std::enable_if< 613 !llvm::is_invocable<Callable, Operation *>::value>::type addDynamicallyLegalOp(Callable && callback)614 addDynamicallyLegalOp(Callable &&callback) { 615 addDynamicallyLegalOp<OpT>( 616 [=](Operation *op) { return callback(cast<OpT>(op)); }); 617 } 618 619 /// Register the given operation as illegal, i.e. this operation is known to 620 /// not be supported by this target. addIllegalOp()621 template <typename OpT> void addIllegalOp() { 622 setOpAction<OpT>(LegalizationAction::Illegal); 623 } addIllegalOp()624 template <typename OpT, typename OpT2, typename... OpTs> void addIllegalOp() { 625 addIllegalOp<OpT>(); 626 addIllegalOp<OpT2, OpTs...>(); 627 } 628 629 /// Mark an operation, that *must* have either been set as `Legal` or 630 /// `DynamicallyLegal`, as being recursively legal. This means that in 631 /// addition to the operation itself, all of the operations nested within are 632 /// also considered legal. An optional dynamic legality callback may be 633 /// provided to mark subsets of legal instances as recursively legal. 634 template <typename OpT> 635 void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) { 636 OperationName opName(OpT::getOperationName(), &ctx); 637 markOpRecursivelyLegal(opName, callback); 638 } 639 template <typename OpT, typename OpT2, typename... OpTs> 640 void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) { 641 markOpRecursivelyLegal<OpT>(callback); 642 markOpRecursivelyLegal<OpT2, OpTs...>(callback); 643 } 644 template <typename OpT, class Callable> 645 typename std::enable_if< 646 !llvm::is_invocable<Callable, Operation *>::value>::type markOpRecursivelyLegal(Callable && callback)647 markOpRecursivelyLegal(Callable &&callback) { 648 markOpRecursivelyLegal<OpT>( 649 [=](Operation *op) { return callback(cast<OpT>(op)); }); 650 } 651 652 /// Register a legality action for the given dialects. 653 void setDialectAction(ArrayRef<StringRef> dialectNames, 654 LegalizationAction action); 655 656 /// Register the operations of the given dialects as legal. 657 template <typename... Names> addLegalDialect(StringRef name,Names...names)658 void addLegalDialect(StringRef name, Names... names) { 659 SmallVector<StringRef, 2> dialectNames({name, names...}); 660 setDialectAction(dialectNames, LegalizationAction::Legal); 661 } addLegalDialect()662 template <typename... Args> void addLegalDialect() { 663 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...}); 664 setDialectAction(dialectNames, LegalizationAction::Legal); 665 } 666 667 /// Register the operations of the given dialects as dynamically legal, i.e. 668 /// requiring custom handling by the target via 'isDynamicallyLegal'. 669 template <typename... Names> addDynamicallyLegalDialect(StringRef name,Names...names)670 void addDynamicallyLegalDialect(StringRef name, Names... names) { 671 SmallVector<StringRef, 2> dialectNames({name, names...}); 672 setDialectAction(dialectNames, LegalizationAction::Dynamic); 673 } 674 template <typename... Args> 675 void addDynamicallyLegalDialect( 676 Optional<DynamicLegalityCallbackFn> callback = llvm::None) { 677 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...}); 678 setDialectAction(dialectNames, LegalizationAction::Dynamic); 679 if (callback) 680 setLegalityCallback(dialectNames, *callback); 681 } 682 template <typename... Args> addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)683 void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) { 684 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...}); 685 setDialectAction(dialectNames, LegalizationAction::Dynamic); 686 setLegalityCallback(dialectNames, callback); 687 } 688 689 /// Register unknown operations as dynamically legal. For operations(and 690 /// dialects) that do not have a set legalization action, treat them as 691 /// dynamically legal and invoke the given callback if valid or 692 /// 'isDynamicallyLegal'. markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn & fn)693 void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) { 694 unknownOpsDynamicallyLegal = true; 695 unknownLegalityFn = fn; 696 } markUnknownOpDynamicallyLegal()697 void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; } 698 699 /// Register the operations of the given dialects as illegal, i.e. 700 /// operations of this dialect are not supported by the target. 701 template <typename... Names> addIllegalDialect(StringRef name,Names...names)702 void addIllegalDialect(StringRef name, Names... names) { 703 SmallVector<StringRef, 2> dialectNames({name, names...}); 704 setDialectAction(dialectNames, LegalizationAction::Illegal); 705 } addIllegalDialect()706 template <typename... Args> void addIllegalDialect() { 707 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...}); 708 setDialectAction(dialectNames, LegalizationAction::Illegal); 709 } 710 711 //===--------------------------------------------------------------------===// 712 // Legality Querying 713 //===--------------------------------------------------------------------===// 714 715 /// Get the legality action for the given operation. 716 Optional<LegalizationAction> getOpAction(OperationName op) const; 717 718 /// If the given operation instance is legal on this target, a structure 719 /// containing legality information is returned. If the operation is not 720 /// legal, None is returned. 721 Optional<LegalOpDetails> isLegal(Operation *op) const; 722 723 protected: 724 /// Runs a custom legalization query for the given operation. This should 725 /// return true if the given operation is legal, otherwise false. isDynamicallyLegal(Operation * op)726 virtual bool isDynamicallyLegal(Operation *op) const { 727 llvm_unreachable( 728 "targets with custom legalization must override 'isDynamicallyLegal'"); 729 } 730 731 private: 732 /// Set the dynamic legality callback for the given operation. 733 void setLegalityCallback(OperationName name, 734 const DynamicLegalityCallbackFn &callback); 735 736 /// Set the dynamic legality callback for the given dialects. 737 void setLegalityCallback(ArrayRef<StringRef> dialects, 738 const DynamicLegalityCallbackFn &callback); 739 740 /// Set the recursive legality callback for the given operation and mark the 741 /// operation as recursively legal. 742 void markOpRecursivelyLegal(OperationName name, 743 const DynamicLegalityCallbackFn &callback); 744 745 /// The set of information that configures the legalization of an operation. 746 struct LegalizationInfo { 747 /// The legality action this operation was given. 748 LegalizationAction action; 749 750 /// If some legal instances of this operation may also be recursively legal. 751 bool isRecursivelyLegal; 752 753 /// The legality callback if this operation is dynamically legal. 754 Optional<DynamicLegalityCallbackFn> legalityFn; 755 }; 756 757 /// Get the legalization information for the given operation. 758 Optional<LegalizationInfo> getOpInfo(OperationName op) const; 759 760 /// A deterministic mapping of operation name and its respective legality 761 /// information. 762 llvm::MapVector<OperationName, LegalizationInfo> legalOperations; 763 764 /// A set of legality callbacks for given operation names that are used to 765 /// check if an operation instance is recursively legal. 766 DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns; 767 768 /// A deterministic mapping of dialect name to the specific legality action to 769 /// take. 770 llvm::StringMap<LegalizationAction> legalDialects; 771 772 /// A set of dynamic legality callbacks for given dialect names. 773 llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns; 774 775 /// An optional legality callback for unknown operations. 776 Optional<DynamicLegalityCallbackFn> unknownLegalityFn; 777 778 /// Flag indicating if unknown operations should be treated as dynamically 779 /// legal. 780 bool unknownOpsDynamicallyLegal; 781 782 /// The current context this target applies to. 783 MLIRContext &ctx; 784 }; 785 786 //===----------------------------------------------------------------------===// 787 // Op Conversion Entry Points 788 //===----------------------------------------------------------------------===// 789 790 /// Below we define several entry points for operation conversion. It is 791 /// important to note that the patterns provided to the conversion framework may 792 /// have additional constraints. See the `PatternRewriter Hooks` section of the 793 /// ConversionPatternRewriter, to see what additional constraints are imposed on 794 /// the use of the PatternRewriter. 795 796 /// Apply a partial conversion on the given operations and all nested 797 /// operations. This method converts as many operations to the target as 798 /// possible, ignoring operations that failed to legalize. This method only 799 /// returns failure if there ops explicitly marked as illegal. If an 800 /// `unconvertedOps` set is provided, all operations that are found not to be 801 /// legalizable to the given `target` are placed within that set. (Note that if 802 /// there is an op explicitly marked as illegal, the conversion terminates and 803 /// the `unconvertedOps` set will not necessarily be complete.) 804 LLVM_NODISCARD LogicalResult 805 applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target, 806 const FrozenRewritePatternList &patterns, 807 DenseSet<Operation *> *unconvertedOps = nullptr); 808 LLVM_NODISCARD LogicalResult 809 applyPartialConversion(Operation *op, ConversionTarget &target, 810 const FrozenRewritePatternList &patterns, 811 DenseSet<Operation *> *unconvertedOps = nullptr); 812 813 /// Apply a complete conversion on the given operations, and all nested 814 /// operations. This method returns failure if the conversion of any operation 815 /// fails, or if there are unreachable blocks in any of the regions nested 816 /// within 'ops'. 817 LLVM_NODISCARD LogicalResult 818 applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target, 819 const FrozenRewritePatternList &patterns); 820 LLVM_NODISCARD LogicalResult 821 applyFullConversion(Operation *op, ConversionTarget &target, 822 const FrozenRewritePatternList &patterns); 823 824 /// Apply an analysis conversion on the given operations, and all nested 825 /// operations. This method analyzes which operations would be successfully 826 /// converted to the target if a conversion was applied. All operations that 827 /// were found to be legalizable to the given 'target' are placed within the 828 /// provided 'convertedOps' set; note that no actual rewrites are applied to the 829 /// operations on success and only pre-existing operations are added to the set. 830 /// This method only returns failure if there are unreachable blocks in any of 831 /// the regions nested within 'ops'. 832 LLVM_NODISCARD LogicalResult 833 applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target, 834 const FrozenRewritePatternList &patterns, 835 DenseSet<Operation *> &convertedOps); 836 LLVM_NODISCARD LogicalResult 837 applyAnalysisConversion(Operation *op, ConversionTarget &target, 838 const FrozenRewritePatternList &patterns, 839 DenseSet<Operation *> &convertedOps); 840 } // end namespace mlir 841 842 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ 843