//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_

#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
class BufferizeTypeConverter;
class FrozenRewritePatternList;

namespace linalg {

struct LinalgFusionOptions;
struct LinalgTilingOptions;

//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;

struct TiledLinalgOp {
  LinalgOp op;
  SmallVector<Operation *, 8> loops;
  SmallVector<Value, 4> tensorResults;
};

/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes);

/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(MLIRContext *context,
                                     BufferizeTypeConverter &converter,
                                     OwningRewritePatternList &patterns);

/// Performs standalone tiling of a single LinalgOp by `tileSizes` and permutes
/// the loop nest according to `interchangeVector`. The permutation is
/// expressed as a list of integers that specify the new ordering of the loop
/// nest. The length of `interchangeVector` must equal the length of
/// `tileSizes`. An empty vector is interpreted as the identity permutation and
/// the transformation returns early.
///
/// Returns a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                     const LinalgTilingOptions &options);
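
// Illustrative usage sketch (not part of the API): tile a linalg op by 8x8x8
// and interchange the two outermost tiled loops. `b` is assumed to be an
// OpBuilder set to a valid insertion point and `matmulOp` a LinalgOp.
//
//   LinalgTilingOptions tilingOptions;
//   tilingOptions.setTileSizes({8, 8, 8}).setInterchange({1, 0, 2});
//   Optional<TiledLinalgOp> tiled = tileLinalgOp(b, matmulOp, tilingOptions);
//   // On success, `tiled->op` is the tiled op and `tiled->loops` the loops.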

/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
/// proceeds as follows:
/// - Find outer parallel loops in these ops that can be fused.
/// - Tile fusable outer parallel loops of the last operation in the sequence.
/// - Fuse the remaining operations with the tiled operation.
///
/// For example, consider the sequence of matmuls below:
///
///   linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
///                 outs(%arg2 : memref<256x32xf32>)
///   linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
///                 outs(%arg4 : memref<256x32xf32>)
///
/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
/// matmuls row-wise. For example, the fused computation for the above is shown
/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
/// along the rows of the matrix. The entire rows of the first matmul operation
/// need to be computed before they can be used for the second matmul. The
/// second matmul is further tiled (similar to normal tiling).
///
///   #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
///   #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
///   scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
///     %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
///       : memref<256x32xf32> to memref<16x32xf32, #map0>
///     %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
///       : memref<256x32xf32> to memref<16x32xf32, #map0>
///     %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
///       : memref<256x32xf32> to memref<16x32xf32, #map0>
///     %3 = subview %arg1[0, 0] [32, 32] [1, 1]
///       : memref<32x32xf32> to memref<32x32xf32, #map1>
///     %4 = subview %arg3[0, 0] [32, 32] [1, 1]
///       : memref<32x32xf32> to memref<32x32xf32, #map1>
///     linalg.matmul
///       ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
///       outs(%0 : memref<16x32xf32, #map0>)
///     linalg.matmul
///       ins(%0, %4 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
///       outs(%1 : memref<16x32xf32, #map0>)
///   }
///
/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
/// size of the former should be the same as the size of the latter). Based on
/// how tile+fuse is implemented, the fused loops are generated based on the
/// last operation in the sequence. For example, the tile sizes for the fused
/// loops are obtained from `tilingOptions.back()`. The following tiling
/// options are handled differently in tile+fuse (compared to tile only):
/// - Interchange of the tiling loops is not supported right now.
/// - Only the fused loops are distributed.
struct TiledAndFusedLinalgOps {
  /// Operation obtained by tiling the last operation in the sequence of `ops`
  /// passed to `tileAndFuseLinalgOps`.
  LinalgOp op;
  /// The dimensions of the loops that are fused.
  std::set<unsigned> fusedLoopDims;
  /// The generated fused operations (created within the fused loops).
  SmallVector<LinalgOp, 1> fusedProducers;
  /// The fused loops generated.
  SmallVector<Operation *, 4> fusedLoops;
};
Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                     const LinalgDependenceGraph &dependenceGraph,
                     const LinalgTilingOptions &tilingOptions);
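
// Illustrative usage sketch (assumptions: `b` is an OpBuilder, `producer` and
// `consumer` are the two LinalgOps of the sequence, and `graph` is a
// LinalgDependenceGraph built for the surrounding function). The fused loops
// are derived from the tiling of the last operation, here tiled by rows only:
//
//   LinalgTilingOptions fuseOptions;
//   fuseOptions.setTileSizes({16, 0, 0});
//   Optional<TiledAndFusedLinalgOps> tiledAndFused =
//       tileAndFuseLinalgOps(b, {producer, consumer}, graph, fuseOptions);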

/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
/// This is an in-place transformation controlled by `interchangeVector`.
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);

/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, it contains the dynamic size of the
/// subview. The callback should return the buffer to use.
using AllocBufferCallbackFn = std::function<Optional<Value>(
    OpBuilder &b, SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
    OperationFolder *folder)>;

/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
using DeallocBufferCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value buffer)>;

/// Callback function type used to insert a copy from the original subview to
/// the subview of the promoted region for the read operands, and from the
/// subview of the promoted region to the original subview for the results.
/// The copy has to happen from `src` to `dst`.
using CopyCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;

struct LinalgPromotionOptions {
  /// Indices of subViews to promote. If `None`, try to promote all operands.
  Optional<DenseSet<unsigned>> operandsToPromote = None;
  LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
    operandsToPromote = DenseSet<unsigned>();
    operandsToPromote->insert(operands.begin(), operands.end());
    return *this;
  }
  /// If the ith element of `useFullTiles` is true the full view should be used
  /// for the promoted buffer of the ith operand in `operandsToPromote`.
  /// Otherwise the partial view will be used. The decision defaults to
  /// `useFullTileBuffersDefault` when `useFullTileBuffers` is None and for
  /// operands missing from `useFullTileBuffers`.
  Optional<llvm::SmallBitVector> useFullTileBuffers = None;
  LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
    unsigned size = useFullTiles.size();
    llvm::SmallBitVector tmp(size, false);
    for (unsigned i = 0; i < size; ++i)
      tmp[i] = useFullTiles[i];
    useFullTileBuffers = tmp;
    return *this;
  }
  /// If true, all operands unspecified by `useFullTileBuffers` will use the
  /// full view; otherwise they use the partial view.
  bool useFullTileBuffersDefault = false;
  LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
    useFullTileBuffersDefault = use;
    return *this;
  }
  /// Allow the use of dynamically-sized buffers.
  bool dynamicBuffers = false;
  LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) {
    dynamicBuffers = dynamic;
    return *this;
  }
  /// Alignment of promoted buffer. If `None` do not specify alignment.
  Optional<unsigned> alignment = None;
  LinalgPromotionOptions &setAlignment(unsigned align) {
    alignment = align;
    return *this;
  }
  /// Use alloca with the default allocation scheme.
  bool useAlloca = false;
  LinalgPromotionOptions &setUseAlloca(bool use) {
    useAlloca = use;
    return *this;
  }
  /// Callback function to do the allocation of the promoted buffer. If None,
  /// the default allocation scheme of allocating a memref<?xi8> buffer
  /// followed by a view operation is used.
  Optional<AllocBufferCallbackFn> allocationFn = None;
  Optional<DeallocBufferCallbackFn> deallocationFn = None;
  LinalgPromotionOptions &
  setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
                               DeallocBufferCallbackFn const &deallocFn) {
    allocationFn = allocFn;
    deallocationFn = deallocFn;
    return *this;
  }
  /// Callback function to do the copy of data to and from the promoted
  /// subview. If None, a linalg.copy is used.
  Optional<CopyCallbackFn> copyInFn = None;
  Optional<CopyCallbackFn> copyOutFn = None;
  LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
                                          CopyCallbackFn const &copyOut) {
    copyInFn = copyIn;
    copyOutFn = copyOut;
    return *this;
  }
};
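
// Illustrative usage sketch (assumptions: `b` is an OpBuilder and `linalgOp`
// is a tiled LinalgOp whose operands are subviews): promote the first two
// operands into aligned, statically sized buffers using the default
// allocation and copy callbacks, then call `promoteSubViews` (declared below).
//
//   LinalgPromotionOptions promotionOptions;
//   promotionOptions.setOperandsToPromote({0, 1})
//       .setUseFullTileBuffersByDefault(true)
//       .setAlignment(16);
//   Optional<LinalgOp> promoted =
//       promoteSubViews(b, linalgOp, promotionOptions);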

/// Creates a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that
/// can be computed for the size of the result of `subView`. Returns the
/// allocated buffer as `fullLocalView` and the view that matches the size of
/// the result of the subview operation as `partialLocalView`.
struct PromotionInfo {
  Value fullLocalView;
  Value partialLocalView;
};
Optional<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
                          AllocBufferCallbackFn allocationFn,
                          OperationFolder *folder = nullptr);

/// Promotes the `subViews` into new buffers allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
///   1. Create a new buffer for a full tile (i.e. not clipped at the
///      boundary).
///   2. Take a full view on the buffer.
///   3. Take a partial slice of the full view in step 2. and copy into it.
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
/// true.
///
/// Returns the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                   LinalgPromotionOptions options,
                                   OperationFolder *folder = nullptr);

/// Emits a suitable vector form for a Linalg op with fully static shape.
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);

//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permuted according to `permutation`.
LogicalResult
interchangeGenericLinalgOpPrecondition(Operation *op,
                                       ArrayRef<unsigned> interchangeVector);

/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
                                          LinalgPromotionOptions options);

/// Rewrite a linalg.generic into a suitable vector.contraction op.
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
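
// Illustrative usage sketch (assumptions: `b` is an OpBuilder and `op` is an
// Operation* implementing the LinalgOp interface with fully static shapes):
// guard the rewrite with its precondition, as the vectorization pattern
// further below does.
//
//   if (succeeded(vectorizeLinalgOpPrecondition(op)))
//     vectorizeLinalgOp(b, op);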

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
struct LinalgTransforms {
  static const StringLiteral kLinalgTransformMarker;
};

/// Helper class to control common attribute matching and setting behavior.
struct LinalgMarker {
  explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
                        Optional<Identifier> replacement = None);
  LinalgMarker(LinalgMarker &&) = default;
  LinalgMarker(const LinalgMarker &) = default;
  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
  void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;

private:
  SmallVector<Identifier, 4> matchDisjunction;
  Optional<Identifier> replacement;
};

///
/// Linalg tiling patterns.
///
/// Apply the `tileLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `tileLinalgOp` for more details.
enum class LinalgTilingLoopType {
  Loops = 0,
  AffineLoops = 1,
  ParallelLoops = 2,
};

using TileSizeComputationFunction =
    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;

struct LinalgTilingOptions {
  /// Computation function that returns the tile sizes for each operation.
  /// Delayed construction of constant tile sizes should occur to interoperate
  /// with folding.
  TileSizeComputationFunction tileSizeComputationFunction = nullptr;

  LinalgTilingOptions &
  setTileSizeComputationFunction(TileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
  /// values must not fold away when tiling. Otherwise, use a more robust
  /// `tileSizeComputationFunction`.
  LinalgTilingOptions &setTileSizes(SmallVector<Value, 4> ts) {
    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);

  /// The interchange vector to reorder the tiled loops.
  SmallVector<unsigned, 4> interchangeVector = {};

  LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
    interchangeVector.assign(interchange.begin(), interchange.end());
    return *this;
  }

  /// The type of tile loops to generate.
  LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;

  LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
    loopType = lt;
    return *this;
  }

  /// When specified, specifies distribution of generated tile loops to
  /// processors.
  Optional<LinalgLoopDistributionOptions> distribution = None;

  LinalgTilingOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    distribution = std::move(distributionOptions);
    return *this;
  }
};

/// Canonicalization patterns relevant to apply after tiling patterns. These
/// are applied automatically by the tiling pass but need to be applied
/// manually when tiling is called programmatically.
OwningRewritePatternList
getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

struct LinalgBaseTilingPattern : public RewritePattern {
  // Entry point to match any LinalgOp OpInterface.
  LinalgBaseTilingPattern(LinalgTilingOptions options,
                          LinalgMarker marker = LinalgMarker(),
                          PatternBenefit benefit = 1);
  // Entry point to match a specific Linalg op.
  LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
                          LinalgTilingOptions options,
                          LinalgMarker marker = LinalgMarker(),
                          PatternBenefit benefit = 1);
  LogicalResult
  matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
                      SmallVectorImpl<Value> &tensorResults) const;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// Options to control tiling.
  LinalgTilingOptions options;
};

template <typename OpTy>
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
  LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options,
                      LinalgMarker marker = LinalgMarker(),
                      PatternBenefit benefit = 1)
      : LinalgBaseTilingPattern(OpTy::getOperationName(), context, options,
                                marker, benefit) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> tensorResults;
    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
                                                            tensorResults)))
      return failure();
    if (tensorResults.empty())
      rewriter.eraseOp(op);
    else
      rewriter.replaceOp(op, tensorResults);
    return success();
  }
};
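
// Illustrative usage sketch (assumptions: `ctx` is the MLIRContext and the
// patterns are later applied with the greedy pattern rewrite driver): tile
// linalg.matmul by 8x8x8 and tag the result so the pattern does not re-apply
// to already tiled ops.
//
//   OwningRewritePatternList patterns;
//   patterns.insert<LinalgTilingPattern<MatmulOp>>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 8, 8}),
//       LinalgMarker({}, Identifier::get("tiled", ctx)));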

struct LinalgFusionOptions {
  /// List of operand indices to use for fusion.
  llvm::SmallSet<unsigned, 1> indicesToFuse = {};
  LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
    indicesToFuse.insert(operands.begin(), operands.end());
    return *this;
  }
};

struct LinalgBaseTileAndFusePattern : public RewritePattern {
  LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
                               const LinalgDependenceGraph &dependenceGraph,
                               LinalgTilingOptions tilingOptions,
                               LinalgFusionOptions fusionOptions,
                               LinalgMarker marker = LinalgMarker(),
                               LinalgMarker fusedOpMarker = LinalgMarker(),
                               LinalgMarker originalOpMarker = LinalgMarker(),
                               PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// Dependence graph needed for fusion.
  const LinalgDependenceGraph &dependenceGraph;
  /// Options to control tiling.
  LinalgTilingOptions tilingOptions;
  /// Options to control fusion.
  LinalgFusionOptions fusionOptions;
  /// Marker to control application of the pattern.
  LinalgMarker marker;
  /// Marker set on the fused op after tile and fuse.
  LinalgMarker fusedOpMarker;
  /// The dependenceGraph is not modifiable, i.e. if the Linalg operations
  /// used to build the dependence graph change then the dependenceGraph
  /// needs to be recomputed. To avoid invalidating the dependenceGraph while
  /// the transformation happens, the original producer can be tagged with a
  /// marker that can later be used to delete the original operations.
  LinalgMarker originalOpMarker;
};

template <typename OpTy>
struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
  LinalgTileAndFusePattern(MLIRContext *context,
                           const LinalgDependenceGraph &dependenceGraph,
                           LinalgTilingOptions tilingOptions,
                           LinalgFusionOptions fusionOptions,
                           LinalgMarker marker = LinalgMarker(),
                           LinalgMarker fusedOpMarker = LinalgMarker(),
                           LinalgMarker originalOpMarker = LinalgMarker(),
                           PatternBenefit benefit = 1)
      : LinalgBaseTileAndFusePattern(
            OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
            fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
};

///
/// Linalg interchange patterns.
///
/// Apply the `interchange` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct LinalgBaseInterchangePattern : public RewritePattern {
  LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
                               ArrayRef<unsigned> interchangeVector,
                               LinalgMarker marker = LinalgMarker(),
                               PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// The interchange vector to reorder the iterators and indexing_maps dims.
  SmallVector<unsigned, 8> interchangeVector;
};

template <typename OpTy>
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
  LinalgInterchangePattern(MLIRContext *context,
                           ArrayRef<unsigned> interchangeVector,
                           LinalgMarker marker = LinalgMarker(),
                           PatternBenefit benefit = 1)
      : LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
                                     interchangeVector, marker, benefit) {}
};

///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViews` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
  LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
                             LinalgPromotionOptions options,
                             LinalgMarker marker = LinalgMarker(),
                             PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// Promotion options.
  LinalgPromotionOptions options;
};

template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
  LinalgPromotionPattern(MLIRContext *context, LinalgPromotionOptions options,
                         LinalgMarker marker = LinalgMarker(),
                         PatternBenefit benefit = 1)
      : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
                                   marker, benefit) {}
};

///
/// Linalg vectorization patterns.
///
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct LinalgBaseVectorizationPattern : public RewritePattern {
  LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
                                 LinalgMarker marker = LinalgMarker(),
                                 PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
};

template <typename OpTy>
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
  LinalgVectorizationPattern(MLIRContext *context,
                             LinalgMarker marker = LinalgMarker(),
                             PatternBenefit benefit = 1)
      : LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
                                       marker, benefit) {}
};
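
// Illustrative usage sketch (assumption: `ctx` is the MLIRContext): vectorize
// only the matmuls that a previous tiling pattern tagged "tiled", and retag
// them as "vectorized" so each op is rewritten at most once.
//
//   patterns.insert<LinalgVectorizationPattern<MatmulOp>>(
//       ctx, LinalgMarker({Identifier::get("tiled", ctx)},
//                         Identifier::get("vectorized", ctx)));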

///
/// Linalg lowering patterns.
///
/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `linalgLowerOpToLoops` for more details.
enum class LinalgLoweringType {
  LibraryCall = 0,
  Loops = 1,
  AffineLoops = 2,
  ParallelLoops = 3
};
template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
  LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
                        LinalgMarker marker = LinalgMarker(),
                        PatternBenefit benefit = 1)
      : RewritePattern(OpTy::getOperationName(), {}, benefit, context),
        marker(marker), loweringType(loweringType) {}
  // TODO: Move implementation to .cpp once named ops are auto-generated.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
      return failure();

    if (loweringType == LinalgLoweringType::LibraryCall) {
      // TODO: Move lowering to library calls here.
      return failure();
    } else if (loweringType == LinalgLoweringType::Loops) {
      if (failed(linalgOpToLoops(rewriter, op)))
        return failure();
    } else if (loweringType == LinalgLoweringType::AffineLoops) {
      if (failed(linalgOpToAffineLoops(rewriter, op)))
        return failure();
    } else if (failed(linalgOpToParallelLoops(rewriter, op))) {
      return failure();
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// Controls whether the pattern lowers to library calls, scf.for, affine.for
  /// or scf.parallel.
  LinalgLoweringType loweringType;
};
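
// Illustrative usage sketch (assumption: `ctx` is the MLIRContext): lower any
// remaining linalg.matmul ops to an scf.for loop nest.
//
//   patterns.insert<LinalgLoweringPattern<MatmulOp>>(
//       ctx, LinalgLoweringType::Loops);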

/// Linalg generalization patterns.

/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    LinalgMarker marker = LinalgMarker());

/// Populates `patterns` with patterns to convert linalg.conv ops to
/// linalg.generic ops.
void populateLinalgConvGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    LinalgMarker marker = LinalgMarker());

//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = std.view %alloc ...
///    %subView = subview %allocOrView ...
///    [optional] linalg.fill(%allocOrView, %cst) ...
///    ...
///    linalg.copy(%in, %subView) ...
///    vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = std.view %alloc ...
///    [unchanged] %subView = subview %allocOrView ...
///    ...
///    vector.transfer_read %in[...], %cst ...
/// ```
/// Where there is no interleaved use between linalg.copy and transfer_read as
/// well as no interleaved use between linalg.fill and linalg.copy (if
/// linalg.fill is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override;
};

/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = std.view %alloc ...
///    %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %allocOrView[...]
///    linalg.copy(%subView, %out)
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = std.view %alloc ...
///    [unchanged] %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %out[...]
/// ```
/// Where there is no interleaved use between transfer_write and linalg.copy.
/// This is a custom rewrite to forward partial writes to
/// vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override;
};
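
// Illustrative usage sketch (assumption: `ctx` is the MLIRContext): the two
// forwarding patterns are typically registered together so that both reads
// from and writes to promoted buffers are forwarded.
//
//   patterns.insert<LinalgCopyVTRForwardingPattern,
//                   LinalgCopyVTWForwardingPattern>(ctx);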

/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and
/// scf.parallel by:
///   1. building an affine map where uses of the induction variable of a loop
///      are replaced by either the min (i.e. `%lb`) or the max
///      (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`) expression,
///      depending on whether the induction variable is used with a positive
///      or negative coefficient.
///   2. checking whether any of the results of this affine map is known to be
///      greater than all other results.
///   3. replacing the AffineMinOp by the result of (2).
// TODO: move to a more appropriate place when it is determined. For now Linalg
// depends both on Affine and SCF but they do not depend on each other.
struct AffineMinSCFCanonicalizationPattern
    : public OpRewritePattern<AffineMinOp> {
  using OpRewritePattern<AffineMinOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMinOp minOp,
                                PatternRewriter &rewriter) const override;
};

/// Converts a Convolution op into a vector contraction.
///
/// Conversion expects the ConvOp dimensions marked in the *mask* as false to
/// be of size 1. This ensures that the ConvOp can be lowered to a vector
/// contraction of the dimensions marked in the *mask* as true.
///
/// A good example for vectorization is ConvNHWCOp, which is a 2-D Conv op
/// with channels as the last dimension. Let's vectorize the last 3 dimensions.
/// The initial op definition looks like this:
/// ```
///   linalg.conv_2d_nhwc %arg0, %arg1, %arg2 :
///     (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>)
/// ```
/// This op can be expressed as a dot product between %arg0 (input) and
/// %arg1 (kernel) which is written into the first entry of %arg2 (output).
/// This is the ConvOp this pass expects and converts into:
/// ```
///   #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
///   #map1 = affine_map<(d0, d1, d2) -> ()>
///   .....
///   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %c0_f32
///     : memref<1x3x3x3xf32>, vector<3x3x3xf32>
///   %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %c0_f32
///     : memref<1x3x3x3xf32>, vector<3x3x3xf32>
///   %2 = vector.contract {indexing_maps = [#map0, #map0, #map1],
///     iterator_types = ["reduction", "reduction", "reduction"]} %0, %1,
///     %c0_f32 : vector<3x3x3xf32>, vector<3x3x3xf32> into f32
///   store %2, %arg2[%c0, %c0, %c0, %c0] : memref<?x?x?x?xf32>
/// ```
/// where the first two operations read the input and kernel memory buffers
/// into vectors. Subsequently, they are contracted together and the result is
/// written to the first entry of the output buffer.
template <typename ConvOp, int N>
class ConvOpVectorization : public OpRewritePattern<ConvOp> {
  using OpRewritePattern<ConvOp>::OpRewritePattern;
  SmallVector<bool, 4> mask;

public:
  ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk)
      : OpRewritePattern<ConvOp>(context) {
    assert(msk.size() == N && "Mask size does not match rank");
    this->mask = msk;
  }

  LogicalResult matchAndRewrite(ConvOp minOp,
                                PatternRewriter &rewriter) const override;
};

//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
/// Helper function to allow applying rewrite patterns, interleaved with more
/// global transformations, in a staged fashion:
///   1. the first stage consists of a list of FrozenRewritePatternList. Each
///      FrozenRewritePatternList in this list is applied once, in order.
///   2. the second stage consists of a single FrozenRewritePatternList that is
///      applied greedily until convergence.
///   3. the third stage consists of applying a lambda, generally used for
///      non-local transformation effects. This allows creating custom fused
///      transformations where patterns can be ordered and applied at a finer
///      granularity than a sequence of traditional compiler passes.
LogicalResult applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
    const FrozenRewritePatternList &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
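
// Illustrative usage sketch (assumptions: `funcOp` is the enclosing FuncOp and
// `tilingPatterns`/`canonicalizationPatterns` are OwningRewritePatternList
// instances built elsewhere): tile in stage 1, canonicalize greedily in
// stage 2, and run a no-op stage 3 lambda.
//
//   SmallVector<FrozenRewritePatternList, 4> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   FrozenRewritePatternList stage2(std::move(canonicalizationPatterns));
//   (void)applyStagedPatterns(funcOp, stage1, stage2,
//                             [](Operation *op) { return success(); });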

//===----------------------------------------------------------------------===//
// Support for sparse tensor code generation.
//
// The sparse compiler part of MLIR lowers a tensor expression formulated as a
// Linalg operation into a sequence of loops depending on what dimensions of
// the tensors are marked dense or sparse. The generated code distinguishes
// between:
//   (1) for-loops that iterate over a single dense dimension,
//   (2) for-loops that iterate over a single sparse dimension,
//   (3) while-loops that co-iterate over several sparse dimensions.
// The for-loops may be subsequently optimized for parallel or vector
// execution.
//
// For more details, see the Dialect/Linalg/Transforms/Sparsification.cpp file.
//===----------------------------------------------------------------------===//

/// Defines a parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" (thus not "reduction") is a candidate
/// for parallelization. The loop is made parallel if (1) allowed by the
/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
/// outermost loop only), and (2) the generated code is an actual for-loop
/// (and not a co-iterating while-loop).
enum class SparseParallelizationStrategy {
  kNone,
  kDenseOuterLoop,
  kAnyStorageOuterLoop,
  kDenseAnyLoop,
  kAnyStorageAnyLoop
  // TODO: support reduction parallelization too?
};

/// Defines a vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate (full SIMD for "parallel" loops and horizontal
/// SIMD for "reduction" loops). A loop is actually vectorized if (1) allowed
/// by the strategy, and (2) the emitted code is an actual for-loop (and not
/// a co-iterating while-loop).
enum class SparseVectorizationStrategy {
  kNone,
  kDenseInnerLoop,
  kAnyStorageInnerLoop
};

/// Defines a type for "pointer" and "index" storage in the sparse storage
/// scheme, with a choice between the native platform-dependent index width,
/// 64-bit integers, or 32-bit integers. A narrow width obviously reduces
/// the memory footprint of the sparse storage scheme, but the width should
/// suffice to define the total required range (viz. the maximum number of
/// stored entries per indirection level for the "pointers" and the maximum
/// value of each tensor index over all dimensions for the "indices").
enum class SparseIntType { kNative, kI64, kI32 };

/// Sparsification options.
struct SparsificationOptions {
  SparsificationOptions(SparseParallelizationStrategy p,
                        SparseVectorizationStrategy v, unsigned vl,
                        SparseIntType pt, SparseIntType it)
      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
        ptrType(pt), indType(it) {}
  SparsificationOptions()
      : SparsificationOptions(SparseParallelizationStrategy::kNone,
                              SparseVectorizationStrategy::kNone, 1u,
                              SparseIntType::kNative, SparseIntType::kNative) {}
  SparseParallelizationStrategy parallelizationStrategy;
  SparseVectorizationStrategy vectorizationStrategy;
  unsigned vectorLength;
  SparseIntType ptrType;
  SparseIntType indType;
};

/// Set up sparsification rewriting rules with the given options.
void populateSparsificationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    const SparsificationOptions &options = SparsificationOptions());
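
// Illustrative usage sketch (assumptions: `ctx` is the MLIRContext and
// `patterns` is an OwningRewritePatternList): parallelize any outer loop
// regardless of storage, vectorize dense inner loops with vector length 16,
// and use 32-bit "pointer"/"index" storage.
//
//   SparsificationOptions sparseOptions(
//       SparseParallelizationStrategy::kAnyStorageOuterLoop,
//       SparseVectorizationStrategy::kDenseInnerLoop,
//       /*vectorLength=*/16, SparseIntType::kI32, SparseIntType::kI32);
//   populateSparsificationPatterns(ctx, patterns, sparseOptions);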

} // namespace linalg
} // namespace mlir

#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_