1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
26
27 namespace tensorflow {
28 namespace {
29
30 #define GEN_PASS_CLASSES
31 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
32
33 using mlir::failure;
34 using mlir::LogicalResult;
35 using mlir::MLIRContext;
36 using mlir::PatternRewriter;
37 using mlir::RewritePatternSet;
38 using mlir::success;
39 using mlir::arith::ConstantIndexOp;
40 using mlir::gml_st::LoopOp;
41 using mlir::linalg::FillOp;
42 using mlir::linalg::GenericOp;
43 using mlir::tensor::ExpandShapeOp;
44 using mlir::vector::TransferReadOp;
45 using mlir::vector::TransferWriteOp;
46
// Exclusive upper bound on the number of elements for vectorizing an untiled
// `linalg.fill`. An untiled fill whose output is a statically-shaped tensor is
// vectorized only if it has strictly fewer elements than this; tensors with
// this many elements or more are left alone. It is expected that such
// operations are tiled to get to small static shapes.
constexpr int64_t kNumElementsThreshold = 1024;
51
52 // Rewrite `vector.transfer_read(linalg.expand_shape)` as
53 // `vector.shape_cast(vector.transfer_read)`.
54 struct TransferReadOfOneDimExpandShape
55 : public mlir::OpRewritePattern<TransferReadOp> {
56 using OpRewritePattern<TransferReadOp>::OpRewritePattern;
57
matchAndRewritetensorflow::__anon2c1e8bb10111::TransferReadOfOneDimExpandShape58 mlir::LogicalResult matchAndRewrite(
59 TransferReadOp vector_read,
60 mlir::PatternRewriter &rewriter) const override {
61 auto expand = vector_read.getSource().getDefiningOp<ExpandShapeOp>();
62 if (!expand) return failure();
63
64 auto expand_src = expand.getSrc();
65 auto expand_src_type = expand.getSrcType();
66 auto expand_dst_type = expand.getResultType();
67 if (expand_src_type.getRank() != 1 || expand_dst_type.getRank() != 2)
68 return failure();
69
70 auto result_type = vector_read.getType().dyn_cast<mlir::ShapedType>();
71 if (!result_type || result_type.getShape() != expand_dst_type.getShape())
72 return failure();
73
74 auto zero = rewriter.create<ConstantIndexOp>(vector_read.getLoc(), 0);
75 auto map = mlir::AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0)},
76 vector_read.getContext());
77 // TODO(pifon): Also support canonicalization in case the map is not an
78 // identity.
79 if (!map.isIdentity()) return failure();
80
81 auto new_read = rewriter.create<TransferReadOp>(
82 vector_read.getLoc(),
83 mlir::VectorType::get(expand_src_type.getShape(),
84 expand_src_type.getElementType()),
85 expand_src, mlir::ValueRange{zero}, mlir::AffineMapAttr::get(map),
86 vector_read.getPadding(),
87 /*mask=*/mlir::Value(), rewriter.getBoolArrayAttr({true}));
88 rewriter.replaceOpWithNewOp<mlir::vector::ShapeCastOp>(
89 vector_read, vector_read.getType(), new_read);
90 return success();
91 }
92 };
93
94 template <typename OpTy>
95 struct VectorizationPattern : public mlir::OpRewritePattern<OpTy> {
VectorizationPatterntensorflow::__anon2c1e8bb10111::VectorizationPattern96 VectorizationPattern(MLIRContext *context,
97 llvm::function_ref<bool(OpTy)> match_fn,
98 mlir::PatternBenefit benefit = 1)
99 : mlir::OpRewritePattern<OpTy>(context, benefit), match_fn(match_fn) {}
100
matchAndRewritetensorflow::__anon2c1e8bb10111::VectorizationPattern101 LogicalResult matchAndRewrite(OpTy op,
102 PatternRewriter &rewriter) const override {
103 if (!match_fn(op)) return failure();
104 return mlir::linalg::vectorize(rewriter, op);
105 }
106
107 private:
108 llvm::function_ref<bool(OpTy)> match_fn;
109 };
110
getDefaultVectorizationPatterns(MLIRContext * ctx)111 RewritePatternSet getDefaultVectorizationPatterns(MLIRContext *ctx) {
112 RewritePatternSet patterns(ctx);
113 mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
114 mlir::vector::populateVectorReductionToContractPatterns(patterns);
115 patterns.add<mlir::linalg::LinalgCopyVTRForwardingPattern,
116 mlir::linalg::LinalgCopyVTWForwardingPattern>(ctx,
117 /*benefit=*/2);
118 TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
119 TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
120 return patterns;
121 }
122
isFillTiledOrSmall(FillOp fill)123 bool isFillTiledOrSmall(FillOp fill) {
124 if (fill->getParentOfType<LoopOp>()) return true;
125
126 // Allow vectorization for static shapes with low number of elements.
127 auto output_type = fill.output().getType().cast<mlir::RankedTensorType>();
128 return output_type.hasStaticShape() &&
129 output_type.getNumElements() < kNumElementsThreshold;
130 }
131
isGenericOpTiledOrOneDimReduction(GenericOp generic)132 bool isGenericOpTiledOrOneDimReduction(GenericOp generic) {
133 if (generic->getParentOfType<LoopOp>()) return true;
134
135 // Allow vectorization of 1D reductions.
136 return generic.getNumLoops() == 1 && generic.getNumReductionLoops() == 1;
137 }
138
139 struct VectorizeTiledOpsPass
140 : public VectorizeTiledOpsBase<VectorizeTiledOpsPass> {
getDependentDialectstensorflow::__anon2c1e8bb10111::VectorizeTiledOpsPass141 void getDependentDialects(mlir::DialectRegistry ®istry) const override {
142 registry.insert<mlir::vector::VectorDialect>();
143 }
144
runOnOperationtensorflow::__anon2c1e8bb10111::VectorizeTiledOpsPass145 void runOnOperation() override {
146 auto func = getOperation();
147 auto ctx = func.getContext();
148
149 RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx);
150 patterns.add<TransferReadOfOneDimExpandShape>(func.getContext());
151 patterns.add<VectorizationPattern<FillOp>>(ctx, isFillTiledOrSmall);
152 patterns.add<VectorizationPattern<GenericOp>>(
153 ctx, isGenericOpTiledOrOneDimReduction);
154 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
155 }
156 };
157
158 } // namespace
159
160 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateVectorizeTiledOpsPass()161 CreateVectorizeTiledOpsPass() {
162 return std::make_unique<VectorizeTiledOpsPass>();
163 }
164
165 } // namespace tensorflow
166