1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "PassDetail.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27
28 #define DEBUG_TYPE "linalg-drop-unit-dims"
29
30 using namespace mlir;
31 using namespace mlir::edsc;
32 using namespace mlir::edsc::intrinsics;
33 using namespace mlir::linalg;
34
35 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
36 /// broadcasting. For example,
37 ///
38 /// ```mlir
39 /// #accesses = [
40 /// affine_map<(d0, d1) -> (0, d1)>,
41 /// affine_map<(d0, d1) -> (d0, 0)>,
42 /// affine_map<(d0, d1) -> (d0, d1)>
43 /// ]
44 ///
45 /// #trait = {
46 /// args_in = 2,
47 /// args_out = 1,
48 /// indexing_maps = #accesses,
49 /// iterator_types = ["parallel", "parallel"],
50 /// library_call = "some_external_fn"
51 /// }
52 ///
53 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
54 /// tensor<5x5xf32>
55 /// {
56 /// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
57 /// tensor<5xf32> into tensor<1x5xf32>
58 /// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
59 /// tensor<5xf32> into tensor<5x1xf32>
60 /// %2 = linalg.generic #trait %0, %1 {
61 /// ^bb0(%arg2: f32, %arg3: f32):
62 /// %3 = addf %arg2, %arg3 : f32
63 /// linalg.yield %3 : f32
64 /// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
65 /// return %2 : tensor<5x5xf32>
/// }
/// ```
///
68 /// would canonicalize to
69 ///
70 /// ```mlir
71 /// #accesses = [
72 /// affine_map<(d0, d1) -> (d1)>,
73 /// affine_map<(d0, d1) -> (d0)>,
74 /// affine_map<(d0, d1) -> (d0, d1)>
75 /// ]
76 ///
77 /// #trait = {
78 /// args_in = 2,
79 /// args_out = 1,
80 /// indexing_maps = #accesses,
81 /// iterator_types = ["parallel", "parallel"],
82 /// library_call = "some_external_fn"
83 /// }
84 ///
85 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
86 /// tensor<5x5xf32>
87 /// {
88 /// %0 = linalg.generic #trait %arg0, %arg1 {
89 /// ^bb0(%arg2: f32, %arg3: f32):
90 /// %3 = addf %arg2, %arg3 : f32
91 /// linalg.yield %3 : f32
92 /// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
93 /// return %0 : tensor<5x5xf32>
/// }
/// ```
95
96 /// Given dims of the iteration space of a structured op that are known to be
97 /// single trip count (`unitDims`), return the indexing maps to use in the
98 /// canonicalized op with these dims removed, given the original `indexingMaps`.
replaceUnitDims(DenseSet<unsigned> & unitDims,ArrayRef<AffineMap> indexingMaps,MLIRContext * context)99 static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
100 ArrayRef<AffineMap> indexingMaps,
101 MLIRContext *context) {
102 if (indexingMaps.empty())
103 return nullptr;
104 unsigned numIterationDims = indexingMaps.front().getNumDims();
105 unsigned numSymbols = indexingMaps.front().getNumSymbols();
106
107 // Compute the replacement for each dim expr.
108 SmallVector<AffineExpr, 4> dimReplacements;
109 dimReplacements.reserve(numIterationDims);
110 unsigned numKeptDims = 0;
111 for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
112 if (unitDims.count(dim))
113 dimReplacements.push_back(getAffineConstantExpr(0, context));
114 else
115 dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
116 }
117
118 // Symbols remain the same.
119 SmallVector<AffineExpr, 4> symReplacements;
120 symReplacements.reserve(numSymbols);
121 for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
122 symReplacements.push_back(getAffineSymbolExpr(symbol, context));
123
124 SmallVector<AffineMap, 4> newIndexingMaps;
125 newIndexingMaps.reserve(indexingMaps.size());
126 for (AffineMap operandMap : indexingMaps) {
127 // Expected indexing maps to have no symbols.
128 if (operandMap.getNumSymbols())
129 return nullptr;
130 newIndexingMaps.push_back(simplifyAffineMap(
131 operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
132 numIterationDims - unitDims.size(),
133 numSymbols)));
134 }
135
136 // Check that the new index maps are invertible. If not, something went
137 // wrong, so abort.
138 if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
139 return nullptr;
140 return ArrayAttr::get(
141 llvm::to_vector<4>(llvm::map_range(
142 newIndexingMaps,
143 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
144 context);
145 }
146
147 /// Modify the region of indexed generic op to drop arguments corresponding to
148 /// loops that are unit trip count.
149 template <typename OpTy>
150 static LogicalResult
replaceBlockArgForUnitDimLoops(OpTy op,const DenseSet<unsigned> & unitDims,PatternRewriter & rewriterp)151 replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
152 PatternRewriter &rewriterp) {
153 return success();
154 }
155
156 template <>
replaceBlockArgForUnitDimLoops(IndexedGenericOp op,const DenseSet<unsigned> & unitDims,PatternRewriter & rewriter)157 LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
158 IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
159 PatternRewriter &rewriter) {
160 OpBuilder::InsertionGuard guard(rewriter);
161 Block *entryBlock = &op->getRegion(0).front();
162 rewriter.setInsertionPointToStart(entryBlock);
163 Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
164 for (unsigned unitDimLoop : unitDims) {
165 entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
166 }
167 SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
168 entryBlock->eraseArguments(unitDimsToErase);
169 return success();
170 }
171
172 namespace {
173 /// Pattern to fold unit-trip count loops in GenericOps.
174 // TODO: Generalize this to indexed-generic as well by modifying the region args
175 // as well.
176 template <typename GenericOpTy>
177 struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
178 using OpRewritePattern<GenericOpTy>::OpRewritePattern;
matchAndRewrite__anonc9322c8b0211::FoldUnitDimLoops179 LogicalResult matchAndRewrite(GenericOpTy op,
180 PatternRewriter &rewriter) const override {
181 SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
182 if (indexingMaps.empty())
183 return failure();
184
185 // Check if any of the iteration dimensions are unit-trip count. They will
186 // end up being unit-trip count if they are used to index into a unit-dim
187 // tensor/memref.
188 AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
189 if (!invertedMap)
190 return failure();
191 SmallVector<int64_t, 4> dims;
192 for (ShapedType shapedType : op.getInputOutputShapedTypes())
193 dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
194 DenseSet<unsigned> unitDims;
195 ArrayAttr iteratorTypes = op.iterator_types();
196 for (auto expr : enumerate(invertedMap.getResults())) {
197 if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
198 if (dims[dimExpr.getPosition()] == 1 &&
199 iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
200 getParallelIteratorTypeName())
201 unitDims.insert(expr.index());
202 }
203 if (unitDims.empty())
204 return failure();
205
206 // Compute the modified indexing maps.
207 MLIRContext *context = rewriter.getContext();
208 ArrayAttr newIndexingMapAttr =
209 replaceUnitDims(unitDims, indexingMaps, context);
210 if (!newIndexingMapAttr)
211 return op.emitError("unable to compute modified indexing_maps");
212
213 // Compute the iterator types of the modified op by dropping the one-trip
214 // count loops.
215 SmallVector<Attribute, 4> newIteratorTypes;
216 for (auto attr : llvm::enumerate(iteratorTypes)) {
217 if (!unitDims.count(attr.index()))
218 newIteratorTypes.push_back(attr.value());
219 }
220
221 rewriter.startRootUpdate(op);
222 op.indexing_mapsAttr(newIndexingMapAttr);
223 op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
224 replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
225 rewriter.finalizeRootUpdate(op);
226 return success();
227 }
228 };
229
230 struct UnitExtentReplacementInfo {
231 RankedTensorType type;
232 AffineMap indexMap;
233 ArrayAttr reassociation;
234 };
235 } // namespace
236
237 /// Utility function for replacing operands/results to a linalg generic
238 /// operation on tensors with unit-extent dimensions. These can be replaced with
239 /// an operand/result with the unit-extent dimension removed. This is only done
240 /// if the indexing map used to access that didimensionmension has a
241 /// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
242 /// Linalg op, and its `indexMap` the utility function returns:
243 /// - the new type with dimensions of size 1 removed.
244 /// - modified index map that can be used to access the replaced result/operand
245 /// - the reassociation that converts from the original tensor type to the
246 /// modified tensor type.
replaceUnitExtents(AffineMap indexMap,RankedTensorType type,MLIRContext * context)247 static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
248 RankedTensorType type,
249 MLIRContext *context) {
250 ArrayRef<int64_t> shape = type.getShape();
251 ArrayRef<AffineExpr> exprs = indexMap.getResults();
252 SmallVector<AffineExpr, 2> reassociations;
253 SmallVector<Attribute, 4> reassociationMaps;
254 SmallVector<AffineExpr, 4> newIndexExprs;
255 SmallVector<int64_t, 4> newShape;
256
257 int64_t origRank = type.getRank();
258 AffineExpr zeroExpr = getAffineConstantExpr(0, context);
259 auto isUnitExtent = [&](int64_t dim) -> bool {
260 return shape[dim] == 1 && exprs[dim] == zeroExpr;
261 };
262
263 unsigned dim = 0;
264 // Fold dimensions that are unit-extent at the beginning of the tensor.
265 while (dim < origRank && isUnitExtent(dim))
266 reassociations.push_back(getAffineDimExpr(dim++, context));
267 while (dim < origRank) {
268 reassociations.push_back(getAffineDimExpr(dim, context));
269 newIndexExprs.push_back(exprs[dim]);
270 newShape.push_back(shape[dim]);
271 // Fold all following dimensions that are unit-extent.
272 while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
273 ++dim;
274 reassociations.push_back(getAffineDimExpr(dim, context));
275 }
276 reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
277 origRank, /*numSymbols = */ 0, reassociations, context)));
278 reassociations.clear();
279 ++dim;
280 }
281 UnitExtentReplacementInfo info = {
282 RankedTensorType::get(newShape, type.getElementType()),
283 AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
284 newIndexExprs, context),
285 ArrayAttr::get(reassociationMaps, context)};
286 return info;
287 }
288
289 namespace {
290
291 /// Pattern to replace tensors operands/results that are unit extents.
292 template <typename GenericOpTy>
293 struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
294 using OpRewritePattern<GenericOpTy>::OpRewritePattern;
matchAndRewrite__anonc9322c8b0411::ReplaceUnitExtentTensors295 LogicalResult matchAndRewrite(GenericOpTy op,
296 PatternRewriter &rewriter) const override {
297 // TODO: support init_tensors and reductions.
298 if (!op.hasTensorSemantics() || !op.init_tensors().empty())
299 return failure();
300
301 MLIRContext *context = rewriter.getContext();
302 Location loc = op.getLoc();
303
304 SmallVector<AffineMap, 4> newIndexingMaps;
305 SmallVector<ArrayAttr, 4> reassociationMaps;
306 SmallVector<ShapedType, 4> newInputOutputTypes;
307 bool doCanonicalization = false;
308 for (auto it :
309 llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
310 auto replacementInfo = replaceUnitExtents(
311 std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
312 context);
313 reassociationMaps.push_back(replacementInfo.reassociation);
314 newIndexingMaps.push_back(replacementInfo.indexMap);
315 newInputOutputTypes.push_back(replacementInfo.type);
316 doCanonicalization |= replacementInfo.type != std::get<1>(it);
317 }
318
319 // If the indexing maps of the result operation are not invertible (i.e. not
320 // legal), abort.
321 if (!doCanonicalization ||
322 !inversePermutation(concatAffineMaps(newIndexingMaps)))
323 return failure();
324
325 // If any operand type change, insert a reshape to convert from the original
326 // type to the new type.
327 // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
328 unsigned flattenedIdx = 0;
329 auto insertReshapes = [&](ValueRange values) {
330 SmallVector<Value, 4> res;
331 res.reserve(values.size());
332 for (auto operand : llvm::enumerate(values)) {
333 if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
334 res.push_back(operand.value());
335 else
336 res.push_back(rewriter.create<linalg::TensorReshapeOp>(
337 loc, newInputOutputTypes[flattenedIdx], operand.value(),
338 reassociationMaps[flattenedIdx]));
339 ++flattenedIdx;
340 }
341 return res;
342 };
343
344 SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
345 SmallVector<Value, 4> newOutputBuffers =
346 insertReshapes(op.output_buffers());
347 SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());
348
349 // If any result type change, insert a reshape to convert from the original
350 // type to the new type.
351 SmallVector<Type, 4> resultTypes;
352 resultTypes.reserve(op.getNumResults());
353 for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
354 resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
355 GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
356 loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
357 newIndexingMaps,
358 llvm::to_vector<4>(
359 op.iterator_types().template getAsValueRange<StringAttr>()));
360 rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
361 replacementOp.region().begin());
362
363 // If any result tensor has a modified shape, then add reshape to recover
364 // the original shape.
365 SmallVector<Value, 4> resultReplacements;
366 for (auto result : llvm::enumerate(replacementOp.getResults())) {
367 unsigned index = result.index() + replacementOp.getNumOperands();
368 RankedTensorType origResultType = op.getResult(result.index())
369 .getType()
370 .template cast<RankedTensorType>();
371 if (origResultType != result.value().getType())
372 resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
373 loc, origResultType, result.value(), reassociationMaps[index]));
374 else
375 resultReplacements.push_back(result.value());
376 }
377 rewriter.replaceOp(op, resultReplacements);
378 return success();
379 }
380 };
381 } // namespace
382
383 namespace {
384 /// Pattern to fold pair of reshape ops where the intermediate has unit-dims for
385 /// example:
386 ///
387 /// %0 = linalg.tensor_reshape %arg0
388 /// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
389 /// : tensor<2048xf32> into tensor<1x4x1x512xf32>
390 /// %1 = linalg.tensor_reshape %0
391 /// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
392 /// affine_map<(d0, d1, d2, d3) -> (d3)>]
393 /// : tensor<1x4x1x512xf32> into tensor<4x512xf32>
394 ///
395 /// can be replaced with
396 ///
397 /// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
398 /// : tensor<2048xf32> into tensor<4x512xf32>
399 ///
400 /// Similarly,
401 ///
402 /// %0 = linalg.tensor_reshape %arg0
403 /// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
404 /// affine_map<(d0, d1, d2, d3) -> (d3)>]
405 /// : tensor<4x512xf32> into tensor<1x4x1x512xf32>
406 /// %1 = linalg.tensor_reshape %0
407 /// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
408 /// : tensor<1x4x1x512xf32> into tensor<2048xf32>
409 ///
410 /// can be replaced with
411 ///
412 /// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
413 /// : tensor<4x512xf32> into tensor<2048xf32>
414 struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
415 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
416
matchAndRewrite__anonc9322c8b0611::FoldReshapeOpWithUnitExtent417 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
418 PatternRewriter &rewriter) const override {
419 // Check that the source operand is created from a reshape as well.
420 TensorReshapeOp parentReshapeOp =
421 reshapeOp.src().getDefiningOp<TensorReshapeOp>();
422 if (!parentReshapeOp)
423 return failure();
424
425 RankedTensorType srcType = reshapeOp.getSrcType(),
426 dstType = reshapeOp.getResultType(),
427 parentSrcType = parentReshapeOp.getSrcType();
428 if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
429 !parentSrcType.hasStaticShape() ||
430 srcType.getRank() < dstType.getRank() ||
431 parentSrcType.getRank() == dstType.getRank())
432 return failure();
433
434 // Check if the result tensor_reshape after folding the reshapeOp and
435 // parentReshapeOp are combined.
436 // If the final tensor_reshape is folding, the parentReshapeOp is
437 // introducing unit-dims, and the reshapeOp does an actual reshape.
438 // If the final tensor_reshape op is expanding, the reshapeOp is
439 // introducing unit-dims, and the parentReshapeOp does an actual reshape.
440 bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
441 ArrayRef<int64_t> expandedShape =
442 isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
443 ArrayRef<int64_t> foldedShape =
444 isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
445
446 unsigned expandedDim = 0, foldedDim = 0;
447 SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
448 foldedShape.size());
449 while (expandedDim < expandedShape.size() &&
450 foldedDim < foldedShape.size()) {
451 int64_t dstSize = foldedShape[foldedDim];
452 int64_t srcSize = expandedShape[expandedDim];
453 while (srcSize < dstSize && expandedDim < expandedShape.size()) {
454 reassociationExprs[foldedDim].push_back(
455 rewriter.getAffineDimExpr(expandedDim++));
456 srcSize *= expandedShape[expandedDim];
457 }
458 if (srcSize == dstSize) {
459 reassociationExprs[foldedDim].push_back(
460 rewriter.getAffineDimExpr(expandedDim++));
461 // If the next dim in foldedShape is not 1, treat subsequent dims in
462 // expandedShape which are 1 to be collapsed.
463 if (foldedDim == foldedShape.size() - 1 ||
464 foldedShape[foldedDim + 1] != 1) {
465 while (expandedDim < expandedShape.size() &&
466 expandedShape[expandedDim] == 1) {
467 reassociationExprs[foldedDim].push_back(
468 rewriter.getAffineDimExpr(expandedDim++));
469 }
470 }
471 } else {
472 return failure();
473 }
474 foldedDim++;
475 }
476 if (expandedDim != expandedShape.size())
477 return failure();
478
479 SmallVector<AffineMap, 4> reassociationMaps =
480 llvm::to_vector<4>(llvm::map_range(
481 reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
482 return AffineMap::get(expandedShape.size(), 0, exprs,
483 rewriter.getContext());
484 }));
485 rewriter.replaceOpWithNewOp<TensorReshapeOp>(
486 reshapeOp, dstType, parentReshapeOp.src(),
487 rewriter.getAffineMapArrayAttr(reassociationMaps));
488 return success();
489 }
490 };
491 } // namespace
492
493 /// Patterns that are used to canonicalize the use of unit-extent dims for
494 /// broadcasting.
populateLinalgFoldUnitExtentDimsPatterns(MLIRContext * context,OwningRewritePatternList & patterns)495 void mlir::populateLinalgFoldUnitExtentDimsPatterns(
496 MLIRContext *context, OwningRewritePatternList &patterns) {
497 patterns
498 .insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
499 ReplaceUnitExtentTensors<GenericOp>,
500 ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
501 TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
502 patterns.insert<FoldReshapeOpWithUnitExtent>(context);
503 }
504
505 namespace {
506 /// Pass that removes unit-extent dims within generic ops.
507 struct LinalgFoldUnitExtentDimsPass
508 : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
runOnFunction__anonc9322c8b0811::LinalgFoldUnitExtentDimsPass509 void runOnFunction() override {
510 OwningRewritePatternList patterns;
511 FuncOp funcOp = getFunction();
512 MLIRContext *context = funcOp.getContext();
513 if (foldOneTripLoopsOnly)
514 patterns.insert<FoldUnitDimLoops<GenericOp>,
515 FoldUnitDimLoops<IndexedGenericOp>>(context);
516 else
517 populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
518 applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
519 }
520 };
521 } // namespace
522
523 std::unique_ptr<OperationPass<FuncOp>>
createLinalgFoldUnitExtentDimsPass()524 mlir::createLinalgFoldUnitExtentDimsPass() {
525 return std::make_unique<LinalgFoldUnitExtentDimsPass>();
526 }
527