/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements input inline fusion.
//
#include <limits>

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/lhlo_elemental_utils.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopFusionUtils.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"

using mlir::memref::LoadOp;

namespace mlir {
namespace lmhlo {

#define GEN_PASS_CLASSES
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"

namespace {

// TODO(disc): It may be worth explicitly adding the I/O buffers to the
// outlined lmhlo::FusionOp and then marking lmhlo::FusionOp as
// IsolatedFromAbove. That way the fusion codegen passes could be
// OperationPasses on lmhlo::FusionOp, reducing compilation overhead.
class InputInlineFusion : public InputInlineFusionPassBase<InputInlineFusion> {
  void runOnFunction() override;
};

}  // end anonymous namespace

std::unique_ptr<FunctionPass> createInputInlineFusionPass() {
  return std::make_unique<InputInlineFusion>();
}

namespace {

// Upper bound on the number of greedy rewrite iterations; each successful
// application of InputInlineFusionPattern inline-fuses one producer op.
constexpr unsigned c_MAX_ITERATION = 4096;

// This pass runs after the LhloLegalizeRootsToParallelLoops pass for the
// XLA-style fusion codegen.
//
// It iteratively looks for the lmhlo op that is the direct producer of the
// nested loops, and then inline-fuses it if the fusion will not form a cycle.
//
// The inline fusion action can be generalized as:
// Step 1: replace the producer lmhlo op with the associated std op inside
// the nested loops.
// Step 2: remove the original Load ops inside the loops and insert new
// Load ops.
//
// If there are multiple LoadOps with the same indices, they will be replaced
// with the same op. This achieves a result similar to GeneratedValueCache.
//
// IR after LhloLegalizeRootsToParallelLoops:
//    "lmhlo.fusion"() ( {
//       lmhlo.aaa(%0, %1, %2)
//       lmhlo.bbb(%2, %3, %4)
//       scf.parallel (...) {
//          memref.load %4[...]
//          ...
//          memref.store ...
//       }
//    })
//
// IR after one round of InputInlineFusionPattern:
//    "lmhlo.fusion"() ( {
//       lmhlo.aaa(%0, %1, %2)
//       scf.parallel (...) {
//          memref.load %2[...]
//          ...
//          memref.store ...
//       }
//    })
//
// Final IR after this pass:
//    "lmhlo.fusion"() ( {
//       scf.parallel (...) {
//          memref.load ...
//          ...
//          memref.store ...
//       }
//    })
class InputInlineFusionPattern : public RewritePattern {
 public:
  explicit InputInlineFusionPattern(MLIRContext* context)
      : RewritePattern(FusionOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override {
    // Skip if this is not the outermost ParallelOp.
    auto fusion = cast<FusionOp>(op);
    auto& parent_block = fusion.region().front();
    SmallVector<scf::ParallelOp, 4> parallel_ops;
    fusion.walk([&](scf::ParallelOp parallel_op) {
      parallel_ops.push_back(parallel_op);
    });
    assert(parallel_ops.size() == 1 &&
           "only one scf::ParallelOp is expected after "
           "LhloLegalizeRootsToParallelLoops");
    scf::ParallelOp parallel_op = parallel_ops.front();
    SmallVector<LoadOp, 4> load_ops;
    parallel_op->walk([&](LoadOp load_op) { load_ops.push_back(load_op); });
    for (auto load_op : load_ops) {
      auto lhlo_op = getFusibleOperation(load_op);
      if (!lhlo_op) continue;
      // 1. In case of:
      //      A = ...
      //      B = op(A)
      //      C = op(A, B)
      //    C should fuse B first before fusing A.
      //    This is the same logic as in the instruction_fusion pass of XLA.
      //
      // 2. When multiple loads consume the same result of lhlo_op and
      //    the load indices are also identical, the IR should be
      //    emitted only once. The other LoadOps should use the cached Value.

      // 'load_ops' that can consume the same cached value.
      SmallVector<LoadOp> same_load_ops;
      bool can_remove_producer;
      if (!checkIfFusible(parallel_op, lhlo_op, load_op, can_remove_producer,
                          same_load_ops))
        continue;
      // 'load_op' is always the one located in the outermost
      // code block among all the 'same_load_ops', because the walker
      // visits in post order.
      if (failed(inlineFuseLhloOp(rewriter, parallel_op, lhlo_op, load_op,
                                  same_load_ops)))
        return failure();
      if (can_remove_producer) rewriter.eraseOp(lhlo_op);
      for (LoadOp to_be_removed : same_load_ops)
        rewriter.eraseOp(to_be_removed);

      // Clean up all the ops that have no LoadOps inside the nested
      // ParallelOps and are not ancestors of any ops that have LoadOps
      // inside the nested ParallelOps.
      cleanUnusedLhloOps(&parent_block);

      return success();
    }
    return failure();
  }

 private:
  Operation* getFusibleOperation(LoadOp load_op) const;
  LogicalResult inlineFuseLhloOp(PatternRewriter& b, Operation* user,
                                 Operation* producer, LoadOp load_op,
                                 const SmallVector<LoadOp>& load_ops) const;
  bool checkIfFusible(scf::ParallelOp user, Operation* producer, LoadOp load_op,
                      bool& can_remove_producer,
                      SmallVector<LoadOp>& load_ops) const;
};

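// Returns the lmhlo op that writes the buffer read by 'load_op', i.e. the
// candidate producer for inline fusion, or nullptr if there is none.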
Operation* InputInlineFusionPattern::getFusibleOperation(LoadOp load_op) const {
  Operation* lhlo_op = nullptr;
  for (auto* user : load_op.getMemRef().getUsers()) {
    if (isa<LmhloOp>(user) && (cast<LmhloOp>(user).getResultBuffer() ==
                               load_op.getOperation()->getOperand(0))) {
      if (lhlo_op)
        llvm::report_fatal_error(
            "More than one lhlo_op writes to one Memref within one fusion");
      lhlo_op = user;
    }
  }
  return lhlo_op;
}

// Checks that there are no other lmhlo consumers of the producer's result
// buffer except the ParallelOp. Collects into 'load_ops' the LoadOps that
// read that buffer at the same indices as 'load_op', and sets
// 'can_remove_producer' to true if no load with different indices remains.
bool InputInlineFusionPattern::checkIfFusible(
    scf::ParallelOp user, Operation* producer, LoadOp load_op,
    bool& can_remove_producer, SmallVector<LoadOp>& load_ops) const {
  load_ops.clear();
  assert(isa<LmhloOp>(producer) && "Unexpected producer in checkIfFusible");
  auto producer_result_memref = cast<LmhloOp>(producer).getResultBuffer();
  can_remove_producer = true;
  auto lhlo_dialect = user->getContext()->getLoadedDialect("lmhlo");
  for (auto* memref_user : producer_result_memref.getUsers()) {
    if ((memref_user->getDialect() == lhlo_dialect) &&
        (memref_user != producer)) {
      return false;
    }
    LoadOp other = dyn_cast<LoadOp>(memref_user);
    if (!other) continue;
    if (other.getMemRef() == load_op.getMemRef() &&
        other.getIndices() == load_op.getIndices())
      load_ops.emplace_back(other);
    else
      can_remove_producer = false;
    // TODO(disc): check that the memref_user is inside the loops
  }
  return true;
}

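// Inline-fuses an elementwise producer: its operand values are loaded at the
// position of 'load_op' and the lmhlo op is mapped to its scalar std op.
//
// Illustrative sketch (types and details elided), fusing an lmhlo.add:
//   Before:
//     "lmhlo.add"(%a, %b, %c)
//     scf.parallel (%i) {
//       %v = memref.load %c[%i]
//       ...
//     }
//   After:
//     scf.parallel (%i) {
//       %va = memref.load %a[%i]
//       %vb = memref.load %b[%i]
//       %v = addf %va, %vb
//       ...
//     }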
template <typename LHLO_OpTy>
bool elemwiseFuseHelper(PatternRewriter& rewriter, Operation* user,
                        Operation* producer, LoadOp load_op,
                        const SmallVector<LoadOp>& load_ops) {
  if (!isa<LHLO_OpTy>(producer) ||
      !LHLO_OpTy::template hasTrait<OpTrait::Elementwise>())
    return false;
  auto loc = user->getLoc();
  SmallVector<Value, 4> operand_values;
  unsigned num_operands = producer->getNumOperands();
  for (unsigned i = 0; i < num_operands - 1; ++i) {
    auto producer_operand = producer->getOperand(i);
    rewriter.setInsertionPoint(load_op);
    operand_values.push_back(
        rewriter.create<LoadOp>(loc, producer_operand, load_op.getIndices()));
  }
  auto inlined_result =
      HloOpToStdScalarOp::map<LHLO_OpTy>(llvm::cast<LHLO_OpTy>(producer),
                                         cast<LmhloOp>(producer)
                                             .getResultBuffer()
                                             .getType()
                                             .cast<MemRefType>()
                                             .getElementType(),
                                         operand_values, &rewriter);

  for (LoadOp to_be_replaced : load_ops)
    to_be_replaced.replaceAllUsesWith(inlined_result);
  return true;
}

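// Inline-fuses a non-elementwise producer (e.g. a broadcast or slice) by
// lowering it elementally at the indices of 'load_op'.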
template <typename LHLO_OpTy>
bool miscFuseHelper(PatternRewriter& rewriter, Operation* user,
                    Operation* opaque_producer, LoadOp load_op,
                    const SmallVector<LoadOp>& load_ops) {
  LHLO_OpTy producer = dyn_cast<LHLO_OpTy>(opaque_producer);
  if (!producer) return false;
  auto loc = user->getLoc();
  rewriter.setInsertionPoint(load_op);
  auto inlined_result =
      elementalLower<LHLO_OpTy>(&rewriter, loc, producer, load_op.getIndices());
  for (LoadOp to_be_replaced : load_ops)
    to_be_replaced.replaceAllUsesWith(inlined_result);
  return true;
}

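// Specialization for scalar ConstOp: the load of a rank-0 constant buffer is
// replaced directly with a std ConstantOp holding the constant's value.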
template <>
bool miscFuseHelper<ConstOp>(PatternRewriter& rewriter, Operation* user,
                             Operation* producer, LoadOp load_op,
                             const SmallVector<LoadOp>& load_ops) {
  if (!isa<ConstOp>(producer)) return false;
  auto memref_type =
      cast<LmhloOp>(producer).getResultBuffer().getType().cast<MemRefType>();
  assert(memref_type.getRank() == 0 && "only scalar ConstOp can be fused");
  auto loc = user->getLoc();
  rewriter.setInsertionPoint(load_op);
  Value inlined_result =
      rewriter.create<ConstantOp>(loc, memref_type.getElementType(),
                                  cast<ConstOp>(producer).value().getValue({}));
  for (LoadOp to_be_replaced : load_ops)
    to_be_replaced.replaceAllUsesWith(inlined_result);
  return true;
}

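// Tries elemwiseFuseHelper for each candidate op type in turn, stopping at
// the first one that matches the producer.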
template <typename First>
bool elemwiseFuseHelperOr(PatternRewriter& rewriter, Operation* user,
                          Operation* producer, LoadOp load_op,
                          const SmallVector<LoadOp>& load_ops) {
  return elemwiseFuseHelper<First>(rewriter, user, producer, load_op, load_ops);
}

template <typename First, typename Second, typename... Rest>
bool elemwiseFuseHelperOr(PatternRewriter& rewriter, Operation* user,
                          Operation* producer, LoadOp load_op,
                          const SmallVector<LoadOp>& load_ops) {
  return elemwiseFuseHelperOr<First>(rewriter, user, producer, load_op,
                                     load_ops) ||
         elemwiseFuseHelperOr<Second, Rest...>(rewriter, user, producer,
                                               load_op, load_ops);
}

// Among 'load_ops', 'load_op' is the one located in the outermost
// code block.
LogicalResult InputInlineFusionPattern::inlineFuseLhloOp(
    PatternRewriter& b, Operation* user, Operation* producer, LoadOp load_op,
    const SmallVector<LoadOp>& load_ops) const {
  if (elemwiseFuseHelperOr<
#define GET_SUPPORTED_OP_LIST
#include "mlir-hlo/utils/disc_supported_list.h.inc"
          >(b, user, producer, load_op, load_ops) ||
      // TODO(disc): Upstream is on the way for more Ops
      miscFuseHelper<RealDynamicSliceOp>(b, user, producer, load_op,
                                         load_ops) ||
      miscFuseHelper<DynamicBroadcastInDimOp>(b, user, producer, load_op,
                                              load_ops) ||
      miscFuseHelper<BroadcastInDimOp>(b, user, producer, load_op, load_ops) ||
      miscFuseHelper<ConstOp>(b, user, producer, load_op, load_ops)) {
    return success();
  }

  return failure();
}

void InputInlineFusion::runOnFunction() {
  auto func = getFunction();
  auto* context = &this->getContext();
  OwningRewritePatternList patterns(context);
  patterns.insert<InputInlineFusionPattern>(context);

  // Just apply the patterns greedily.
  // There should always be one scf.ParallelOp in the fusion.
  auto config = GreedyRewriteConfig();
  config.maxIterations = c_MAX_ITERATION;
  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config))) {
    signalPassFailure();
  }

  // There should be no lmhlo ops left after inline fusion, except for the
  // ConstOp of ColReduction, which for now cannot be properly optimized by
  // the general DCE pass.
  std::vector<Operation*> to_be_removed;
  func.walk([&](FusionOp fusion) {
    fusion.region().walk([&](LmhloOp op) {
      if (isa<TerminatorOp>(op)) {
        return;
      }
      if (isa<ConstOp>(op)) {
        // TODO(disc): Check that the ConstOp is from a ReduceOp
        to_be_removed.push_back(op);
        return;
      }
      op.emitError("unexpected remaining operation in a FusionOp");
      signalPassFailure();
    });
  });
  for (auto op : to_be_removed) {
    op->erase();
  }
}

}  // namespace

}  // namespace lmhlo
}  // namespace mlir