/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Analysis/BufferViewFlowAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
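// Defining GEN_PASS_CLASSES before including the tablegen-generated header
// below emits the ...PassBase classes (e.g. VectorizationPassBase) that the
// passes in this file derive from.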
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

// This function takes ForOps that contain AffineMinOps and possibly peels off
// the last iteration of the loop. This is done in cases where it is provable
// that the AffineMinOp evaluates to the loop's step value in every iteration
// except possibly the last one. Some additional cleanup is done to simplify
// the IR; it is correct given knowledge of what this transformation is doing,
// but would generally be unwieldy as a canonicalization-like pattern.
//
// This pass is only necessary due to inefficiencies in VectorTransferSplit
// that are unlikely to be fixed upstream. If that changes, this pass can be
// fully removed.
//
// Example:
// scf.for %i = %c0 to %c11 step %c2
//   %a = affine.min(%c2, %c11-%i)
//
// Becomes:
// scf.for %i = %c0 to %c10 step %c2
//   %a = %c2
// scf.if %one_more_iter
//   %a = affine.min(2, %c11-%i)
//
// This is possible because we can determine that the min will always be 2
// except for the last iteration.
void SplitSCFForOp(scf::ForOp scf_for) {
  // The steps are as follows:
  // 1. Validate that there are min_ops to be modified in this function.
  // 2. Create the boundary that decides whether the min_op evaluates to the
  // loop's step value or to the computed value based upon the iteration value.
  // 3. Create the primary loop that does all the work except for possibly the
  // last iteration of the loop, and replace all relevant min_ops with the step.
  // 4. Create the final iteration, remove the step from relevant min_ops, and
  // additionally modify related if/else ops to have a constant condition based
  // on what we know about this loop structure.

  // Match only when the lower bound is zero and the step is constant.
  // TODO(TPOPP): Requiring a constant step and lower bound simplifies things
  // but isn't necessarily needed.
  auto lower_bound_op =
      llvm::dyn_cast<ConstantOp>(scf_for.lowerBound().getDefiningOp());
  if (!lower_bound_op) {
    return;
  }
  auto lower_bound_value = lower_bound_op.getValue().dyn_cast<IntegerAttr>();
  if (!lower_bound_value || lower_bound_value.getInt() != 0) {
    return;
  }

  auto step_bound_op =
      llvm::dyn_cast<ConstantOp>(scf_for.step().getDefiningOp());
  if (!step_bound_op) {
    return;
  }
  auto step_bound_value = step_bound_op.getValue().dyn_cast<IntegerAttr>();
  if (!step_bound_value) {
    return;
  }

  auto loc = scf_for.getLoc();
  ImplicitLocOpBuilder b(loc, scf_for);

  // This lambda determines whether a min_op is an operation that can be
  // transformed after loop splitting. The decision relies on the affine
  // function the op computes in terms of its loop's induction variable and the
  // bounds of the original for loop.
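  // Concretely, every result of the op's affine map must be either the
  // constant step value or some spelling of `upper_bound - iv` (symbol minus
  // dim, dim minus dim, or constant minus dim), and at least one result must
  // be the step constant.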
  auto is_op_of_interest = [&](AffineMinOp min_op, Value iv) {
    bool min_by_step = false;
    for (auto i : min_op.getAffineMap().getResults()) {
      if (i == b.getAffineConstantExpr(step_bound_value.getInt())) {
        min_by_step = true;
        continue;
      }
      if (i == b.getAffineSymbolExpr(0) - b.getAffineDimExpr(0) &&
          min_op.getDimOperands().front() == iv &&
          min_op.getSymbolOperands().front() == scf_for.upperBound())
        continue;
      if (i == b.getAffineDimExpr(0) - b.getAffineDimExpr(1) &&
          min_op.getDimOperands().drop_front().front() == iv &&
          min_op.getDimOperands().front() == scf_for.upperBound())
        continue;
      if (auto idx_op = scf_for.upperBound().getDefiningOp<ConstantIndexOp>()) {
        auto val = idx_op.getValue();
        if (i == b.getAffineConstantExpr(val) - b.getAffineDimExpr(0) &&
            min_op.getDimOperands().front() == iv)
          continue;
      }
      return false;
    }
    return min_by_step;
  };

  // Determine if the loop should be split based on the existence of
  // AffineMinOps of an expected form.
  llvm::SmallVector<AffineMinOp, 1> min_ops;
  scf_for->walk([&](AffineMinOp min_op) {
    if (is_op_of_interest(min_op, scf_for.getInductionVar()))
      min_ops.push_back(min_op);
  });
  if (min_ops.empty()) {
    return;
  }

  // Split the loop just before a possible last iteration.
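  // The split point is ub - ((ub - lb) mod step), i.e. the iv value of the
  // peeled final iteration when one exists (and equal to ub otherwise). In the
  // example above (lb=0, ub=11, step=2) this is 10.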
  b.setInsertionPoint(scf_for);
  Value split_point = b.create<SubIOp>(
      scf_for.upperBound(),
      b.create<UnsignedRemIOp>(
          b.create<SubIOp>(scf_for.upperBound(), scf_for.lowerBound()),
          scf_for.step()));

  // New primary loop with relevant min ops replaced with their constant value.
  BlockAndValueMapping mapper;
  auto new_loop = llvm::cast<scf::ForOp>(b.clone(*scf_for, mapper));
  new_loop.setUpperBound(split_point);

  new_loop->walk([&](AffineMinOp min_op) {
    if (is_op_of_interest(min_op, new_loop.getInductionVar()))
      min_op->replaceAllUsesWith(llvm::makeArrayRef(scf_for.step()));
  });

  // Peeled loop iteration (or nothing, if the data size and step size align
  // perfectly).
  BlockAndValueMapping tail_mapper;
  tail_mapper.map(scf_for.getRegionIterArgs(), new_loop.results());
  tail_mapper.map(scf_for.getInductionVar(), split_point);
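  // The then-branch runs the original body once more at iv == split_point; the
  // else-branch only clones the loop's terminator so that any iter_args are
  // forwarded unchanged when the loop produces results.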
  auto tail_if = b.create<scf::IfOp>(
      scf_for.getResultTypes(),
      b.create<CmpIOp>(CmpIPredicate::ult, split_point, scf_for.upperBound()),
      [&](OpBuilder &then_b, Location loc) {
        for (auto &op : *scf_for.getBody()) {
          then_b.clone(op, tail_mapper);
        }
      }, scf_for->getNumResults() ?
      [&](OpBuilder &else_b, Location loc) {
        else_b.clone(scf_for.getBody()->back(), tail_mapper);
      } : static_cast<function_ref<void(OpBuilder &, Location)>>(nullptr));

  tail_if->walk([&](AffineMinOp min_op) {
    SmallVector<AffineExpr> exprs;

    if (!is_op_of_interest(min_op, split_point)) return;

    ImplicitLocOpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(min_op);

    // This lambda is called on comparisons that use the min_ops of interest in
    // the last loop iteration. Through loop splitting, we know that the min
    // result is strictly less than the step value there. We can therefore
    // evaluate the cmpi from its predicate plus which operand position the
    // min_op occupies (and, implicitly, where the step value sits).
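    // For example, `cmpi ult, %min, %step` on the peeled iteration folds to
    // true, while `cmpi uge, %min, %step` folds to false.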
    auto is_true_cmp = [](CmpIPredicate pred, bool min_is_op_0) {
      switch (pred) {
        // This loop splitting guarantees the step is not equal to the min on
        // the last iteration.
        case CmpIPredicate::eq:
          return false;
        case CmpIPredicate::ne:
          return true;
        case CmpIPredicate::sle:
        case CmpIPredicate::slt:
        case CmpIPredicate::ule:
        case CmpIPredicate::ult:
          return min_is_op_0;
        case CmpIPredicate::sge:
        case CmpIPredicate::sgt:
        case CmpIPredicate::uge:
        case CmpIPredicate::ugt:
          return !min_is_op_0;
      }
    };

    for (auto user : min_op->getUsers()) {
      if (auto cmp = dyn_cast<CmpIOp>(user)) {
        if (cmp.getOperand(0) == min_op.getResult() &&
            cmp.getOperand(1) == step_bound_op) {
          cmp.replaceAllUsesWith(
              b.create<ConstantIntOp>(is_true_cmp(cmp.predicate(), true), 1)
                  .getResult());
          cmp.erase();
        } else if (cmp.getOperand(0) == step_bound_op &&
                   cmp.getOperand(1) == min_op.getResult()) {
          cmp.replaceAllUsesWith(
              b.create<ConstantIntOp>(is_true_cmp(cmp.predicate(), false), 1)
                  .getResult());
        }
      }
    }

    // Replace the min_op with a simplified min_op that removes the constant
    // step option. This will be further simplified after affine ops are
    // lowered.
    auto map = min_op.getAffineMap();
    for (auto i : map.getResults()) {
      if (i != b.getAffineConstantExpr(step_bound_value.getInt()))
        exprs.push_back(i);
    }

    Value new_min = b.createOrFold<AffineMinOp>(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
                       b.getContext()),
        min_op.operands());

    min_op->replaceAllUsesWith(llvm::makeArrayRef(new_min));
  });

  scf_for->replaceAllUsesWith(tail_if.results());
  scf_for.erase();
}

// Removes memref::AllocaOps and other ops interacting with those memrefs when
// it is provable that this will not change the results of the program. This is
// determined by confirming that every consumer of every alias only creates
// further aliases or writes data to an alias, and never reads from or
// otherwise interacts with the memref.
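//
// Illustrative example (hypothetical IR, assuming %buf and its views have no
// other users):
//   %buf = memref.alloca() : memref<4xf32>
//   linalg.fill(...)   // only writes into %buf
// Since %buf is only ever written to, both ops can be erased.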
void RemoveDeadMemrefCode(FuncOp func) {
  BufferViewFlowAnalysis baa(func);
  llvm::SmallSet<Operation *, 8> to_remove;

  // Gather all operations interacting with memrefs guaranteed to never be read
  // from.
  func->walk([&](memref::AllocaOp op) {
    llvm::SmallVector<Operation *> maybe_to_remove;
    for (auto &alias : baa.resolve(op.getResult())) {
      for (auto user : alias.getUsers()) {
        if (!(isa<ViewLikeOpInterface>(user) ||
              (isa<linalg::CopyOp>(user) &&
               alias == cast<linalg::CopyOp>(user).output()) ||
              (isa<linalg::FillOp>(user) &&
               alias == cast<linalg::FillOp>(user).output()))) {
          return;
        }
        maybe_to_remove.push_back(user);
      }
    }
    to_remove.insert(maybe_to_remove.begin(), maybe_to_remove.end());
    to_remove.insert(op);
  });

  // Erase after the walk to avoid corrupting data being traversed.
  for (auto *op : to_remove) {
    op->dropAllUses();
    op->erase();
  }
}

struct VectorizationPass : public VectorizationPassBase<VectorizationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }

  void runOnFunction() override {
    // This runs in two stages:
    // 1. Tile, promote, and vectorize to create elementwise operations on
    //    <(1x)*4xty> memrefs.
    // 2. Cast <(1x)*4xty> memrefs to <4xty>.
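    // Here <(1x)*4xty> denotes a shape with any number of leading unit
    // dimensions, e.g. 1x1x4xf32.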
    auto f = getFunction();

    // Stage 1: Vectorize to form static shaped computations.
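    // Tile every loop to size 1 except the innermost, which is tiled to 4, so
    // each tile is a (1x)*4 block.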
    auto tiling_options =
        linalg::LinalgTilingOptions().setTileSizeComputationFunction(
            [](OpBuilder b, Operation *op) {
              auto num_loops = llvm::cast<linalg::LinalgOp>(op).getNumLoops();
              SmallVector<Value> tiles(
                  num_loops, b.create<ConstantIndexOp>(op->getLoc(), 1));
              if (!tiles.empty())
                tiles.back() = b.create<ConstantIndexOp>(op->getLoc(), 4);
              return tiles;
            });
    auto alignment = 16;
    mlir::linalg::CodegenStrategy()
        .tile<mlir::linalg::GenericOp>(tiling_options)
        .promote<mlir::linalg::GenericOp>(
            mlir::linalg::LinalgPromotionOptions()
                .setAlignment(alignment)
                .setUseFullTileBuffersByDefault(true)
                .setUseAlloca(true))
        .vectorize<mlir::linalg::GenericOp>()
        .setVectorTransformsOptions(
            mlir::vector::VectorTransformsOptions().setVectorTransferSplit(
                mlir::vector::VectorTransferSplit::VectorTransfer))
        .setVectorTransferToSCFOptions(
            mlir::VectorTransferToSCFOptions().setUnroll(true))
        .transform(f);

    // Stage 2: Remove extent-1 dims to ensure correct rank-1 vectorization.
    auto ctx = f.getContext();
    OwningRewritePatternList patterns(ctx);
    mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(f, std::move(patterns));
  }
};

}  // namespace

std::unique_ptr<FunctionPass> CreateVectorizationPass() {
  return std::make_unique<VectorizationPass>();
}

struct VectorizationCleanupPass
    : public VectorizationCleanupPassBase<VectorizationCleanupPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect,
                    vector::VectorDialect>();
  }

  void runOnFunction() override {
    getFunction().walk([](scf::ForOp op) { SplitSCFForOp(op); });

    RemoveDeadMemrefCode(getFunction());
  }
};

std::unique_ptr<FunctionPass> CreateVectorizationCleanupPass() {
  return std::make_unique<VectorizationCleanupPass>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir