/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "mlir/Analysis/BufferViewFlowAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

// This function takes ForOps that contain AffineMinOps and possibly peels off
// the last iteration of the loop. This is done in cases where it is provable
// that the AffineMinOp is deterministic in all cases except the possible last
// iteration. Some additional cleanup is done to simplify the IR; it is correct
// given knowledge of what this transformation does, but would generally be
// unwieldy to express as a canonicalization-like pattern.
//
// This pass is only necessary due to inefficiencies in VectorTransferSplit
// that are unlikely to be fixed upstream. If that changes, this pass can be
// fully removed.
//
// Example:
//   scf.for %i = %c0 to %c11 step %c2
//     %a = affine.min(%c2, %c11-%i)
//
// Becomes:
//   scf.for %i = %c0 to %c10 step %c2
//     %a = %c2
//   scf.if %one_more_iter
//     %a = affine.min(%c2, %c11-%i)
//
// This is possible because we can determine that the min will always be the
// step value (2 here) except for the last iteration.
void SplitSCFForOp(scf::ForOp scf_for) {
  // The transformation proceeds in the following steps:
  // 1. Validate that there are min_ops to be modified in this function.
  // 2. Create the boundary that decides whether the min_op evaluates to the
  // loop's step value or to a value computed from the induction variable.
  // 3. Create the primary loop that does all the work except for possibly the
  // last iteration of the loop, and replace all relevant min_ops with the step.
  // 4. Create the final iteration, remove the step from relevant min_ops, and
  // additionally modify related if/else ops to have a constant condition based
  // on what we know about this loop structure.
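  //
  // Using the header example: with lb = 0, ub = 11, and step = 2, step 2
  // computes the boundary 11 - (11 mod 2) = 10, so the primary loop covers
  // %i in {0, 2, 4, 6, 8} and the peeled scf.if handles %i = 10.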

  // Match only when the lower bound is zero and the step is constant.
  // TODO(TPOPP): Requiring a constant step and lower bound simplifies things
  // but isn't necessarily needed.
  auto lower_bound_op = llvm::dyn_cast_or_null<ConstantOp>(
      scf_for.lowerBound().getDefiningOp());
  if (!lower_bound_op) {
    return;
  }
  auto lower_bound_value = lower_bound_op.getValue().dyn_cast<IntegerAttr>();
  if (!lower_bound_value || lower_bound_value.getInt() != 0) {
    return;
  }

  auto step_bound_op =
      llvm::dyn_cast_or_null<ConstantOp>(scf_for.step().getDefiningOp());
  if (!step_bound_op) {
    return;
  }
  auto step_bound_value = step_bound_op.getValue().dyn_cast<IntegerAttr>();
  if (!step_bound_value) {
    return;
  }

  auto loc = scf_for.getLoc();
  ImplicitLocOpBuilder b(loc, scf_for);

  // This function determines whether the min_op is an operation that can be
  // transformed after loop splitting. This relies on the function the op
  // computes of its loop's induction variable and on the bounds of the
  // original for loop.
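  //
  // For example, an op like
  //   %a = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)>(%i)[%ub]
  // is of interest when %i is the loop's induction variable, %ub is the
  // loop's upper bound, and 2 is the loop's constant step.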
  auto is_op_of_interest = [&](AffineMinOp min_op, Value iv) {
    bool min_by_step = false;
    for (auto i : min_op.getAffineMap().getResults()) {
      if (i == b.getAffineConstantExpr(step_bound_value.getInt())) {
        min_by_step = true;
        continue;
      }
      if (i == b.getAffineSymbolExpr(0) - b.getAffineDimExpr(0) &&
          min_op.getDimOperands().front() == iv &&
          min_op.getSymbolOperands().front() == scf_for.upperBound())
        continue;
      if (i == b.getAffineDimExpr(0) - b.getAffineDimExpr(1) &&
          min_op.getDimOperands().drop_front().front() == iv &&
          min_op.getDimOperands().front() == scf_for.upperBound())
        continue;
      if (auto idx_op = scf_for.upperBound().getDefiningOp<ConstantIndexOp>()) {
        auto val = idx_op.getValue();
        if (i == b.getAffineConstantExpr(val) - b.getAffineDimExpr(0) &&
            min_op.getDimOperands().front() == iv)
          continue;
      }
      return false;
    }
    return min_by_step;
  };

  // Determine if the loop should be split based on the existence of
  // AffineMinOps of an expected form.
  llvm::SmallVector<AffineMinOp, 1> min_ops;
  scf_for->walk([&](AffineMinOp min_op) {
    if (is_op_of_interest(min_op, scf_for.getInductionVar()))
      min_ops.push_back(min_op);
  });
  if (min_ops.empty()) {
    return;
  }

  // Split the loop just before a possible last iteration.
  b.setInsertionPoint(scf_for);
  Value split_point = b.create<SubIOp>(
      scf_for.upperBound(),
      b.create<UnsignedRemIOp>(
          b.create<SubIOp>(scf_for.upperBound(), scf_for.lowerBound()),
          scf_for.step()));
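  // I.e. split_point = ub - ((ub - lb) mod step). When the trip count divides
  // evenly, split_point == ub and the peeled scf.if below never executes.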

  // New primary loop with relevant min ops replaced with their constant value.
  BlockAndValueMapping mapper;
  auto new_loop = llvm::cast<scf::ForOp>(b.clone(*scf_for, mapper));
  new_loop.setUpperBound(split_point);

  new_loop->walk([&](AffineMinOp min_op) {
    if (is_op_of_interest(min_op, new_loop.getInductionVar()))
      min_op->replaceAllUsesWith(llvm::makeArrayRef(scf_for.step()));
  });

  // Peeled loop iteration (or nothing if data and step sizes are perfectly
  // aligned).
  BlockAndValueMapping tail_mapper;
  tail_mapper.map(scf_for.getRegionIterArgs(), new_loop.results());
  tail_mapper.map(scf_for.getInductionVar(), split_point);
  auto tail_if = b.create<scf::IfOp>(
      scf_for.getResultTypes(),
      b.create<CmpIOp>(CmpIPredicate::ult, split_point, scf_for.upperBound()),
      [&](OpBuilder &then_b, Location loc) {
        for (auto &op : *scf_for.getBody()) {
          then_b.clone(op, tail_mapper);
        }
      },
      scf_for->getNumResults()
          ? [&](OpBuilder &else_b, Location loc) {
              else_b.clone(scf_for.getBody()->back(), tail_mapper);
            }
          : static_cast<function_ref<void(OpBuilder &, Location)>>(nullptr));

  tail_if->walk([&](AffineMinOp min_op) {
    SmallVector<AffineExpr> exprs;

    if (!is_op_of_interest(min_op, split_point)) return;

    ImplicitLocOpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(min_op);

    // This function is called on comparisons that use the min_ops of interest
    // in the last loop iteration. Through loop splitting, we know that the min
    // result there is strictly less than the step value. Therefore, the
    // predicate and the operand position of the min_op (with the step value in
    // the other position) are enough to evaluate the cmpi.
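    //
    // For example, on the peeled iteration `cmpi ult, %min, %step` folds to
    // true, `cmpi uge, %min, %step` folds to false, and `cmpi eq, %min, %step`
    // folds to false.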
    auto is_true_cmp = [](CmpIPredicate pred, bool min_is_op_0) {
      switch (pred) {
        // This loop splitting guarantees the step is not equal to the min on
        // the last iteration, so equality never holds.
        case CmpIPredicate::eq:
          return false;
        case CmpIPredicate::ne:
          return true;
        case CmpIPredicate::sle:
        case CmpIPredicate::slt:
        case CmpIPredicate::ule:
        case CmpIPredicate::ult:
          return min_is_op_0;
        case CmpIPredicate::sge:
        case CmpIPredicate::sgt:
        case CmpIPredicate::uge:
        case CmpIPredicate::ugt:
          return !min_is_op_0;
      }
    };

    for (auto user : min_op->getUsers()) {
      if (auto cmp = dyn_cast<CmpIOp>(user)) {
        if (cmp.getOperand(0) == min_op.getResult() &&
            cmp.getOperand(1) == step_bound_op) {
          cmp.replaceAllUsesWith(
              b.create<ConstantIntOp>(is_true_cmp(cmp.predicate(), true), 1)
                  .getResult());
          cmp.erase();
        } else if (cmp.getOperand(0) == step_bound_op &&
                   cmp.getOperand(1) == min_op.getResult()) {
          cmp.replaceAllUsesWith(
              b.create<ConstantIntOp>(is_true_cmp(cmp.predicate(), false), 1)
                  .getResult());
        }
      }
    }

    // Replace the min_op with a simplified min_op that removes the constant
    // step option. This will be further simplified after affine ops are
    // lowered.
    auto map = min_op.getAffineMap();
    for (auto i : map.getResults()) {
      if (i != b.getAffineConstantExpr(step_bound_value.getInt()))
        exprs.push_back(i);
    }

    Value new_min = b.createOrFold<AffineMinOp>(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
                       b.getContext()),
        min_op.operands());

    min_op->replaceAllUsesWith(llvm::makeArrayRef(new_min));
  });

  scf_for->replaceAllUsesWith(tail_if.results());
  scf_for.erase();
}

// Removes memref alloca ops and other ops interacting with those memrefs if
// it is provable that this will not change the results of the program. This
// is determined by confirming that all consumers of all aliases only create
// another alias or write data to an alias, but never read from or otherwise
// interact with the memref.
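//
// For example, an alloca that is only used to create views, as the output of
// linalg.copy, or as the output of linalg.fill is never read, so the alloca,
// its views, and the writes into it can all be erased.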
void RemoveDeadMemrefCode(FuncOp func) {
  BufferViewFlowAnalysis baa(func);
  llvm::SmallSet<Operation *, 8> to_remove;

  // Gather all operations interacting with memrefs guaranteed to never be read
  // from.
  func->walk([&](memref::AllocaOp op) {
    llvm::SmallVector<Operation *> maybe_to_remove;
    for (auto &alias : baa.resolve(op.getResult())) {
      for (auto user : alias.getUsers()) {
        if (!(isa<ViewLikeOpInterface>(user) ||
              (isa<linalg::CopyOp>(user) &&
               alias == cast<linalg::CopyOp>(user).output()) ||
              (isa<linalg::FillOp>(user) &&
               alias == cast<linalg::FillOp>(user).output()))) {
          return;
        }
        maybe_to_remove.push_back(user);
      }
    }
    to_remove.insert(maybe_to_remove.begin(), maybe_to_remove.end());
    to_remove.insert(op);
  });

  // Erase after the walk to avoid corrupting data being traversed.
  for (auto *op : to_remove) {
    op->dropAllUses();
    op->erase();
  }
}

struct VectorizationPass : public VectorizationPassBase<VectorizationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }

  void runOnFunction() override {
    // This runs in two stages:
    // 1. Tile, promote, and vectorize to create elementwise operations on
    //    <(1x)*4xty> memrefs.
    // 2. Cast <(1x)*4xty> memrefs to <4xty>.
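    //
    // For example, a rank-2 elementwise linalg.generic is tiled into 1x4
    // tiles, promoted into aligned stack buffers, and vectorized into
    // vector<1x4xty> operations, which stage 2 then flattens to vector<4xty>.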
    auto f = getFunction();

    // Stage 1: Vectorize to form static shaped computations.
    auto tiling_options =
        linalg::LinalgTilingOptions().setTileSizeComputationFunction(
            [](OpBuilder b, Operation *op) {
              auto num_loops = llvm::cast<linalg::LinalgOp>(op).getNumLoops();
              SmallVector<Value> tiles(
                  num_loops, b.create<ConstantIndexOp>(op->getLoc(), 1));
              if (!tiles.empty())
                tiles.back() = b.create<ConstantIndexOp>(op->getLoc(), 4);
              return tiles;
            });
    auto alignment = 16;
    mlir::linalg::CodegenStrategy()
        .tile<mlir::linalg::GenericOp>(tiling_options)
        .promote<mlir::linalg::GenericOp>(
            mlir::linalg::LinalgPromotionOptions()
                .setAlignment(alignment)
                .setUseFullTileBuffersByDefault(true)
                .setUseAlloca(true))
        .vectorize<mlir::linalg::GenericOp>()
        .setVectorTransformsOptions(
            mlir::vector::VectorTransformsOptions().setVectorTransferSplit(
                mlir::vector::VectorTransferSplit::VectorTransfer))
        .setVectorTransferToSCFOptions(
            mlir::VectorTransferToSCFOptions().setUnroll(true))
        .transform(f);

    // Stage 2: Remove extent-1 dims to ensure correct rank-1 vectorization.
    auto ctx = f.getContext();
    OwningRewritePatternList patterns(ctx);
    mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(f, std::move(patterns));
  }
};

}  // namespace

std::unique_ptr<FunctionPass> CreateVectorizationPass() {
  return std::make_unique<VectorizationPass>();
}

struct VectorizationCleanupPass
    : public VectorizationCleanupPassBase<VectorizationCleanupPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect,
                    vector::VectorDialect>();
  }

  void runOnFunction() override {
    getFunction().walk([](scf::ForOp op) { SplitSCFForOp(op); });

    RemoveDeadMemrefCode(getFunction());
  }
};

std::unique_ptr<FunctionPass> CreateVectorizationCleanupPass() {
  return std::make_unique<VectorizationCleanupPass>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir