1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements inline fusion.
17 //
18 #include <limits>
19
20 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/lhlo_elemental_utils.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
23 #include "mlir/Analysis/LoopAnalysis.h"
24 #include "mlir/Analysis/Utils.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34 #include "mlir/Transforms/LoopFusionUtils.h"
35 #include "mlir/Transforms/LoopUtils.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "mlir/Transforms/Utils.h"
38
39 using mlir::memref::LoadOp;
40
41 namespace mlir {
42 namespace lmhlo {
43
44 #define GEN_PASS_CLASSES
45 #include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
46
47 namespace {
48
// TODO(disc): It may be worth explicitly adding the I/O buffers to the
// outlining of lmhlo::FusionOp and then marking lmhlo::FusionOp as
// IsolatedFromAbove. That way the fusion codegen passes could be
// OperationPasses on lmhlo::FusionOp, reducing compilation overhead.
// Function pass that repeatedly inline-fuses lmhlo producer ops into the
// nested scf.parallel loops emitted by LhloLegalizeRootsToParallelLoops.
class InputInlineFusion : public InputInlineFusionPassBase<InputInlineFusion> {
  void runOnFunction() override;
};
56
57 } // end anonymous namespace
58
createInputInlineFusionPass()59 std::unique_ptr<FunctionPass> createInputInlineFusionPass() {
60 return std::make_unique<InputInlineFusion>();
61 }
62
63 namespace {
64
// Upper bound on greedy-rewrite iterations; large enough to effectively run
// the pattern to a fixpoint (one producer is fused per iteration).
constexpr unsigned c_MAX_ITERATION = 4096;
66
67 // This pass works after LhloLegalizeRootsToParallelLoops pass for the
68 // XLA-style fusion codegen.
69 //
70 // It iteratively looks for the lmhlo op which is the direct producer of the
71 // nested loops, and then inline fuse it if the fusion will not form a cycle.
72 //
// The inline fusion action can be generalized as:
// step 1: replace the producer lmhlo op with the associated std op inside
//         the nested loops.
// step 2: remove the original Load ops inside the loops and insert new
//         Load ops.
77 //
78 // If there are multiple LoadOps with the same indices, they will be replaced
79 // with the same op. This obtains the similar result as GeneratedValueCache.
80 //
81 // IR after LhloLegalizeRootsToParallelLoops:
82 // "lmhlo.fusion"() ( {
83 // lmhlo.aaa(%0, %1, %2)
84 // lmhlo.bbb(%2, %3, %4)
85 // scf.parallel (...) {
86 // memref.load %4[...]
87 // ...
88 // memref.store ...
89 // }
90 // })
91 //
92 // IR after one round of InputInlineFusionPattern:
93 // "lmhlo.fusion"() ( {
94 // lmhlo.aaa(%0, %1, %2)
95 // scf.parallel (...) {
96 // memref.load %2[...]
97 // ...
98 // memref.store ...
99 // }
100 // })
101 //
102 // Final IR after this pass:
103 // "lmhlo.fusion"() ( {
104 // scf.parallel (...) {
105 // memref.load ...
106 // ...
107 // memref.store ...
108 // }
109 // })
110 class InputInlineFusionPattern : public RewritePattern {
111 public:
InputInlineFusionPattern(MLIRContext * context)112 explicit InputInlineFusionPattern(MLIRContext* context)
113 : RewritePattern(FusionOp::getOperationName(), 1, context) {}
114
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const115 LogicalResult matchAndRewrite(Operation* op,
116 PatternRewriter& rewriter) const override {
117 // skip if not the most outter ParallelOp
118 auto fusion = cast<FusionOp>(op);
119 auto& parent_block = fusion.region().front();
120 SmallVector<scf::ParallelOp, 4> parallel_ops;
121 fusion.walk([&](scf::ParallelOp parallel_op) {
122 parallel_ops.push_back(parallel_op);
123 });
124 assert(parallel_ops.size() == 1 &&
125 "only one scf::ParallelOp is expected after "
126 "LhloLegalizeRootsToParallelLoops");
127 scf::ParallelOp parallel_op = parallel_ops.front();
128 SmallVector<LoadOp, 4> load_ops;
129 parallel_op->walk([&](LoadOp load_op) { load_ops.push_back(load_op); });
130 for (auto load_op : load_ops) {
131 auto lhlo_op = getFusibleOperation(load_op);
132 if (!lhlo_op) continue;
133 // 1, in case of:
134 // A = ...
135 // B = op(A)
136 // C = op(A, B)
137 // C should fuse B first before fusing A.
138 // This is the same logic as in instruction_fusion pass of XLA
139 //
140 // 2, When multiple loads consumes the same result of lhlo_op and
141 // the load indices are also identical, the ir should be
142 // emitted only once. Other LoadOps should used cached Value.
143
144 // 'load_ops' that can consume the same cached value
145 SmallVector<LoadOp> same_load_ops;
146 bool can_remove_producer;
147 if (!checkIfFusible(parallel_op, lhlo_op, load_op, can_remove_producer,
148 same_load_ops))
149 continue;
150 // 'load_op' is always the one that locates in the most
151 // external code block among all the 'same_load_ops', because the walker
152 // is in the post order sequence.
153 if (failed(inlineFuseLhloOp(rewriter, parallel_op, lhlo_op, load_op,
154 same_load_ops)))
155 return failure();
156 if (can_remove_producer) rewriter.eraseOp(lhlo_op);
157 for (LoadOp to_be_removed : same_load_ops)
158 rewriter.eraseOp(to_be_removed);
159
160 // Clean all the ops that do not have LoadOps inside the nested
161 // ParallelOps and is not the ancestor of any ops that have LoadOps
162 // inside the nested ParallelOps.
163 cleanUnusedLhloOps(&parent_block);
164
165 return success();
166 }
167 return failure();
168 }
169
170 private:
171 Operation* getFusibleOperation(LoadOp load_op) const;
172 LogicalResult inlineFuseLhloOp(PatternRewriter& b, Operation* user,
173 Operation* producer, LoadOp load_op,
174 const SmallVector<LoadOp>& load_ops) const;
175 bool checkIfFusible(scf::ParallelOp user, Operation* producer, LoadOp load_op,
176 bool& can_remove_producer,
177 SmallVector<LoadOp>& load_ops) const;
178 };
179
getFusibleOperation(LoadOp load_op) const180 Operation* InputInlineFusionPattern::getFusibleOperation(LoadOp load_op) const {
181 Operation* lhlo_op = nullptr;
182 for (auto* user : load_op.getMemRef().getUsers()) {
183 if (isa<LmhloOp>(user) && (cast<LmhloOp>(user).getResultBuffer() ==
184 load_op.getOperation()->getOperand(0))) {
185 if (lhlo_op)
186 llvm::report_fatal_error(
187 "More than one lhlo_op write to one Memref within one fusion");
188 lhlo_op = user;
189 }
190 }
191 return lhlo_op;
192 }
193
194 // Check if there are no other consumers of the producer
195 // except the ParallelOp.
checkIfFusible(scf::ParallelOp user,Operation * producer,LoadOp load_op,bool & can_remove_producer,SmallVector<LoadOp> & load_ops) const196 bool InputInlineFusionPattern::checkIfFusible(
197 scf::ParallelOp user, Operation* producer, LoadOp load_op,
198 bool& can_remove_producer, SmallVector<LoadOp>& load_ops) const {
199 load_ops.clear();
200 assert(isa<LmhloOp>(producer) && "Unexpected producer in checkIfFusible");
201 auto producer_result_memref = cast<LmhloOp>(producer).getResultBuffer();
202 can_remove_producer = true;
203 auto lhlo_dialect = user->getContext()->getLoadedDialect("lmhlo");
204 for (auto* memref_user : producer_result_memref.getUsers()) {
205 if ((memref_user->getDialect() == lhlo_dialect) &&
206 (memref_user != producer)) {
207 return false;
208 }
209 LoadOp other = dyn_cast<LoadOp>(memref_user);
210 if (!other) continue;
211 if (other.getMemRef() == load_op.getMemRef() &&
212 other.getIndices() == load_op.getIndices())
213 load_ops.emplace_back(other);
214 else
215 can_remove_producer = false;
216 // TODO(disc): check the memref_user is inside the loops
217 }
218 return true;
219 }
220
221 template <typename LHLO_OpTy>
elemwiseFuseHelper(PatternRewriter & rewriter,Operation * user,Operation * producer,LoadOp load_op,const SmallVector<LoadOp> & load_ops)222 bool elemwiseFuseHelper(PatternRewriter& rewriter, Operation* user,
223 Operation* producer, LoadOp load_op,
224 const SmallVector<LoadOp>& load_ops) {
225 if (!isa<LHLO_OpTy>(producer) ||
226 !LHLO_OpTy::template hasTrait<OpTrait::Elementwise>())
227 return false;
228 auto loc = user->getLoc();
229 SmallVector<Value, 4> operand_values;
230 unsigned num_operands = producer->getNumOperands();
231 for (unsigned i = 0; i < num_operands - 1; ++i) {
232 auto producer_operand = producer->getOperand(i);
233 rewriter.setInsertionPoint(load_op);
234 operand_values.push_back(
235 rewriter.create<LoadOp>(loc, producer_operand, load_op.getIndices()));
236 }
237 auto inlined_result =
238 HloOpToStdScalarOp::map<LHLO_OpTy>(llvm::cast<LHLO_OpTy>(producer),
239 cast<LmhloOp>(producer)
240 .getResultBuffer()
241 .getType()
242 .cast<MemRefType>()
243 .getElementType(),
244 operand_values, &rewriter);
245
246 for (LoadOp to_be_replaced : load_ops)
247 to_be_replaced.replaceAllUsesWith(inlined_result);
248 return true;
249 }
250
251 template <typename LHLO_OpTy>
miscFuseHelper(PatternRewriter & rewriter,Operation * user,Operation * opaque_producer,LoadOp load_op,const SmallVector<LoadOp> & load_ops)252 bool miscFuseHelper(PatternRewriter& rewriter, Operation* user,
253 Operation* opaque_producer, LoadOp load_op,
254 const SmallVector<LoadOp>& load_ops) {
255 LHLO_OpTy producer = dyn_cast<LHLO_OpTy>(opaque_producer);
256 if (!producer) return false;
257 auto loc = user->getLoc();
258 rewriter.setInsertionPoint(load_op);
259 auto inlined_result =
260 elementalLower<LHLO_OpTy>(&rewriter, loc, producer, load_op.getIndices());
261 for (LoadOp to_be_replaced : load_ops)
262 to_be_replaced.replaceAllUsesWith(inlined_result);
263 return true;
264 }
265
266 template <>
miscFuseHelper(PatternRewriter & rewriter,Operation * user,Operation * producer,LoadOp load_op,const SmallVector<LoadOp> & load_ops)267 bool miscFuseHelper<ConstOp>(PatternRewriter& rewriter, Operation* user,
268 Operation* producer, LoadOp load_op,
269 const SmallVector<LoadOp>& load_ops) {
270 if (!isa<ConstOp>(producer)) return false;
271 auto memref_type =
272 cast<LmhloOp>(producer).getResultBuffer().getType().cast<MemRefType>();
273 assert(memref_type.getRank() == 0 && "only scalar ConstOp can be fused");
274 auto loc = user->getLoc();
275 rewriter.setInsertionPoint(load_op);
276 Value inlined_result =
277 rewriter.create<ConstantOp>(loc, memref_type.getElementType(),
278 cast<ConstOp>(producer).value().getValue({}));
279 for (LoadOp to_be_replaced : load_ops)
280 to_be_replaced.replaceAllUsesWith(inlined_result);
281 return true;
282 }
283
284 template <typename First>
elemwiseFuseHelperOr(PatternRewriter & rewriter,Operation * user,Operation * producer,LoadOp load_op,const SmallVector<LoadOp> & load_ops)285 bool elemwiseFuseHelperOr(PatternRewriter& rewriter, Operation* user,
286 Operation* producer, LoadOp load_op,
287 const SmallVector<LoadOp>& load_ops) {
288 return elemwiseFuseHelper<First>(rewriter, user, producer, load_op, load_ops);
289 }
290
291 template <typename First, typename Second, typename... Rest>
elemwiseFuseHelperOr(PatternRewriter & rewriter,Operation * user,Operation * producer,LoadOp load_op,const SmallVector<LoadOp> & load_ops)292 bool elemwiseFuseHelperOr(PatternRewriter& rewriter, Operation* user,
293 Operation* producer, LoadOp load_op,
294 const SmallVector<LoadOp>& load_ops) {
295 return elemwiseFuseHelperOr<First>(rewriter, user, producer, load_op,
296 load_ops) ||
297 elemwiseFuseHelperOr<Second, Rest...>(rewriter, user, producer,
298 load_op, load_ops);
299 }
300
// Dispatches the actual inline fusion of `producer` into the loop nest.
// `load_op` is the member of `load_ops` located in the outermost code block;
// all of them are replaced by the inlined value. Returns failure() if no
// helper knows how to lower this producer.
LogicalResult InputInlineFusionPattern::inlineFuseLhloOp(
    PatternRewriter& b, Operation* user, Operation* producer, LoadOp load_op,
    const SmallVector<LoadOp>& load_ops) const {
  // The shared list of supported element-wise ops expands (via the inc file
  // below) into the template argument list of elemwiseFuseHelperOr.
  if (elemwiseFuseHelperOr<
#define GET_SUPPORTED_OP_LIST
#include "mlir-hlo/utils/disc_supported_list.h.inc"
          >(b, user, producer, load_op, load_ops) ||
      // TODO(disc): Upstream is on the way for more Ops
      miscFuseHelper<RealDynamicSliceOp>(b, user, producer, load_op,
                                         load_ops) ||
      miscFuseHelper<DynamicBroadcastInDimOp>(b, user, producer, load_op,
                                              load_ops) ||
      miscFuseHelper<BroadcastInDimOp>(b, user, producer, load_op, load_ops) ||
      miscFuseHelper<ConstOp>(b, user, producer, load_op, load_ops)) {
    return success();
  }

  return failure();
}
322
runOnFunction()323 void InputInlineFusion::runOnFunction() {
324 auto func = getFunction();
325 auto* context = &this->getContext();
326 OwningRewritePatternList patterns(context);
327 patterns.insert<InputInlineFusionPattern>(context);
328
329 // Just apply the patterns greedily.
330 // There should always be one scf.ParallelOp in the fusion.
331 auto config = GreedyRewriteConfig();
332 config.maxIterations = c_MAX_ITERATION;
333 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config))) {
334 signalPassFailure();
335 }
336
337 // there should be no lmhlo ops after inline fusion,
338 // except for the ConstOp of ColReduction, which for now cannot be
339 // properly optimized by general DCE pass
340 std::vector<Operation*> to_be_removed;
341 func.walk([&](FusionOp fusion) {
342 fusion.region().walk([&](LmhloOp op) {
343 if (isa<TerminatorOp>(op)) {
344 return;
345 }
346 if (isa<ConstOp>(op)) {
347 // TODO(disc): Check the ConstOp is from ReduceOp
348 to_be_removed.push_back(op);
349 return;
350 }
351 op.emitError("unexpected remaining operation in a FusionOp");
352 signalPassFailure();
353 });
354 });
355 for (auto op : to_be_removed) {
356 op->erase();
357 }
358 }
359
360 } // namespace
361
362 } // namespace lmhlo
363 } // namespace mlir
364